From c601de4f8061f28c650534fbafdbed60e262e721 Mon Sep 17 00:00:00 2001 From: ykarnati Date: Tue, 27 Jan 2026 11:48:41 -0800 Subject: [PATCH 01/30] add pp stage checkers to p2p communicator --- .../core/pipeline_parallel/p2p_communication.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py index ac839c21f18..0d2afc73460 100644 --- a/megatron/core/pipeline_parallel/p2p_communication.py +++ b/megatron/core/pipeline_parallel/p2p_communication.py @@ -7,6 +7,7 @@ import torch.distributed as dist from megatron.core.model_parallel_config import ModelParallelConfig +from megatron.core.pipeline_parallel.utils import is_pp_first_stage, is_pp_last_stage from megatron.core.utils import nvtx_decorator # Types @@ -162,6 +163,21 @@ def __init__(self, pp_group: dist.ProcessGroup, config: ModelParallelConfig): else None ) + @property + def is_pp_first_stage(self) -> bool: + """Return True if pp first stage.""" + return is_pp_first_stage(self.pp_group) + + @property + def is_pp_last_stage(self) -> bool: + """Return True if pp last stage.""" + return is_pp_last_stage(self.pp_group) + + @property + def num_warmup_microbatches(self) -> int: + """Return number of warmup microbatches.""" + return self.pp_group.size() - self.pp_group.rank() - 1 + def _communicate_shapes(self, tensor_send_next, tensor_send_prev, recv_prev, recv_next): """Communicate tensor shapes between stages. Used to communicate tensor shapes before the actual tensor communication happens. From 84ae4f0f1cae303c802650f375d5c71faaf9009b Mon Sep 17 00:00:00 2001 From: ykarnati Date: Tue, 27 Jan 2026 11:49:38 -0800 Subject: [PATCH 02/30] add process group collection wrapper --- megatron/core/process_groups_config.py | 140 ++++++++++++++++++++++++- 1 file changed, 139 insertions(+), 1 deletion(-) diff --git a/megatron/core/process_groups_config.py b/megatron/core/process_groups_config.py index ef8f31ea150..9fa0e080ccc 100644 --- a/megatron/core/process_groups_config.py +++ b/megatron/core/process_groups_config.py @@ -4,7 +4,7 @@ from dataclasses import dataclass, field, fields from functools import partial -from typing import List, Optional +from typing import Dict, List, Optional import torch @@ -569,3 +569,141 @@ def setup_process_groups_for_ddp( result['ep_group'] = pg_collection.ep return result + + +@dataclass +class ProcessGroupCollectionWrapper: + """Wrapper for multiple process group collections in multi-module pipelines. + + Used when a rank participates in multiple modules (e.g., colocated encoder + LLM). + The language_model key identifies which module is the language model (used for + CP size extraction and other LLM-specific operations). + + Attributes: + module_collections: Dict mapping module names to ProcessGroupCollection objects + language_model: Key identifying the language model module (None if no LLM on this rank) + + Example: + # Colocated rank with encoder and LLM + wrapper = ProcessGroupCollectionWrapper( + module_collections={ + "encoder": encoder_pg, + "llm": llm_pg + }, + language_model="llm" + ) + + # Rank with dual encoders (no LLM) + wrapper = ProcessGroupCollectionWrapper( + module_collections={ + "encoder_1": encoder_1_pg, + "encoder_2": encoder_2_pg + }, + language_model=None + ) + + # Single module (can also use ProcessGroupCollection directly) + wrapper = ProcessGroupCollectionWrapper( + module_collections={"llm": llm_pg}, + language_model="llm" + ) + + # Usage + cp_size = wrapper.get_language_model_cp_size() + encoder_pg = wrapper["encoder_1"] # Dict-like access + has_llm = wrapper.has_language_model() + """ + + module_collections: Dict[str, ProcessGroupCollection] + language_model: Optional[str] = None + + def __post_init__(self): + if not self.module_collections: + raise ValueError("module_collections dict cannot be empty") + if self.language_model is not None: + if self.language_model not in self.module_collections: + raise ValueError( + f"language_model '{self.language_model}' not found in " + f"module_collections keys: {list(self.module_collections.keys())}" + ) + + def get_language_model_collection(self) -> ProcessGroupCollection: + """Get the language model's process group collection. + + Returns: + ProcessGroupCollection for the language model. + + Raises: + ValueError: If no language model is specified for this wrapper. + """ + if self.language_model is None: + raise ValueError("No language model specified for this wrapper") + return self.module_collections[self.language_model] + + def get_language_model_cp_size(self) -> int: + """Get context parallel size for the language model. + + Returns: + Context parallel size for the language model. + + Raises: + ValueError: If no language model is specified for this wrapper. + """ + return self.get_language_model_collection().cp.size() + + def has_language_model(self) -> bool: + """Check if this rank has a language model. + + Returns: + True if this rank has a language model, False otherwise. + """ + return self.language_model is not None + + def get_module_collection(self, module_name: str) -> ProcessGroupCollection: + """Get process group collection for a specific module. + + Args: + module_name: Name of the module. + + Returns: + ProcessGroupCollection for the specified module. + + Raises: + ValueError: If module_name is not found in collections. + """ + if module_name not in self.module_collections: + raise ValueError( + f"Module '{module_name}' not found in collections. " + f"Available: {list(self.module_collections.keys())}" + ) + return self.module_collections[module_name] + + def __len__(self): + """Return the number of modules in this wrapper.""" + return len(self.module_collections) + + def __getitem__(self, module_name: str): + """Get process group collection for a module using dict-like access.""" + return self.module_collections[module_name] + + def __iter__(self): + """Iterate over all process group collections.""" + return iter(self.module_collections.values()) + + def keys(self): + """Return module names.""" + return self.module_collections.keys() + + def values(self): + """Return process group collections.""" + return self.module_collections.values() + + def items(self): + """Return (module_name, collection) pairs.""" + return self.module_collections.items() + + def __repr__(self): + """Return a concise representation showing modules and their language model status.""" + modules_str = ', '.join(self.module_collections.keys()) + lm_str = f", language_model='{self.language_model}'" if self.language_model else "" + return f"ProcessGroupCollectionWrapper(modules=[{modules_str}]{lm_str})" From 0fa3dd8322a940784ebcfd02ad6af7c4f0d848a3 Mon Sep 17 00:00:00 2001 From: ykarnati Date: Tue, 27 Jan 2026 13:59:54 -0800 Subject: [PATCH 03/30] support multimodule pipelining in 1f1b schedule --- megatron/core/pipeline_parallel/schedules.py | 190 +++++++++++-------- megatron/core/pipeline_parallel/utils.py | 86 ++++++++- 2 files changed, 198 insertions(+), 78 deletions(-) diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index edca62be375..0868492f6eb 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -2,7 +2,7 @@ import contextlib from functools import partial -from typing import Callable, Iterator, List, Optional, Union +from typing import Callable, Dict, Iterator, List, Optional, Union import torch from torch.autograd.variable import Variable @@ -12,14 +12,21 @@ from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( FineGrainedActivationOffloadingInterface as off_interface, ) +from megatron.core.pipeline_parallel.multimodule_communicator import ( + MultiModulePipelineCommunicator, +) from megatron.core.pipeline_parallel.p2p_communication import P2PCommunicator from megatron.core.pipeline_parallel.utils import ( + backward_step_multimodule, is_pp_first_stage, is_pp_last_stage, is_vp_first_stage, is_vp_last_stage, ) -from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.process_groups_config import ( + ProcessGroupCollection, + ProcessGroupCollectionWrapper, +) from megatron.core.transformer.cuda_graphs import create_cudagraphs from megatron.core.transformer.enums import CudaGraphScope from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler @@ -157,9 +164,28 @@ def deallocate_output_tensor(out, deallocate_pipeline_outputs=False): This method should be called right after the output tensor has been sent to the next pipeline stage. At this point, the output tensor is only useful for its '.grad_fn' field, and not its '.data'. + + Supports multiple formats: + - torch.Tensor: Deallocates the tensor directly + - List[Tensor]: Recursively deallocates each element + - Dict[str, Tensor]: Recursively deallocates each value (for multi-module pipelines) ''' if (out is None) or (not deallocate_pipeline_outputs): return + + # Handle dict format (multi-module pipelines) + if isinstance(out, dict): + for value in out.values(): + deallocate_output_tensor(value, deallocate_pipeline_outputs) + return + + # Handle list format + if isinstance(out, list): + for item in out: + deallocate_output_tensor(item, deallocate_pipeline_outputs) + return + + # Base case: deallocate tensor assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__ assert out._base is None, "counter-productive to free a view of another tensor." out.data = torch.empty((1,), device=out.device, dtype=out.dtype) @@ -443,14 +469,13 @@ def forward_step( return [output_tensor], num_tokens -def backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config): +def backward_step(input_tensor, output_tensor, output_tensor_grad, config): """Backward step through passed-in output tensor. If last stage, output_tensor_grad is None, otherwise gradient of loss with respect to stage's output tensor. - Returns gradient of loss with respect to input tensor (None if first - stage).""" + Returns gradient of loss with respect to input tensor (None if first stage).""" # NOTE: This code currently can handle at most one skip connection. It # needs to be modified slightly to support arbitrary numbers of skip @@ -649,7 +674,7 @@ def forward_backward_no_pipelining( total_num_tokens += num_tokens if not forward_only: backward_step( - input_tensor, output_tensor, output_tensor_grad, model_type, config + input_tensor, output_tensor, output_tensor_grad, config ) # Run computation for last microbatch out of context handler (want to # synchronize gradients). @@ -672,7 +697,7 @@ def forward_backward_no_pipelining( total_num_tokens += num_tokens if not forward_only: - backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config) + backward_step(input_tensor, output_tensor, output_tensor_grad, config) if config.finalize_model_grads_func is not None and not forward_only: # Finalize model grads (perform full grad all-reduce / reduce-scatter for @@ -863,6 +888,7 @@ def forward_backward_pipelining_with_interleaving( ) tp_group = parallel_state.get_tensor_model_parallel_group() cp_group = parallel_state.get_context_parallel_group() + cp_size = cp_group.size() embd_group = parallel_state.get_embedding_group(check_initialized=False) pp_group = parallel_state.get_pipeline_model_parallel_group() pos_emb_group = parallel_state.get_position_embedding_group(check_initialized=False) @@ -903,6 +929,7 @@ def forward_backward_pipelining_with_interleaving( assert hasattr(pg_collection, 'dp_cp'), "pg_collection must have a dp_cp_group" tp_group = pg_collection.tp cp_group = pg_collection.cp + cp_size = cp_group.size() else: raise ValueError( "Invalid combination of p2p_communicator, pg_collection" @@ -1234,7 +1261,7 @@ def forward_step_helper(virtual_microbatch_id, checkpoint_activations_microbatch input_tensor, forward_data_store, config, - cp_group_size=pg_collection.cp.size(), + cp_group_size=cp_size, collect_non_loss_data=collect_non_loss_data, checkpoint_activations_microbatch=checkpoint_activations_microbatch, is_first_microbatch=check_first_val_step( @@ -1300,7 +1327,7 @@ def backward_step_helper(virtual_microbatch_id): ) input_tensor_grad = backward_step( - input_tensor, output_tensor, output_tensor_grad, model_type, config + input_tensor, output_tensor, output_tensor_grad, config ) backward_step_helper_postprocess(virtual_microbatch_id) @@ -1909,7 +1936,7 @@ def pp_post_backward(input_tensor_grad, vp_stage=None): # If defer_embedding_wgrad_compute is enabled we need to do the # weight gradient GEMM's here. finish_embedding_wgrad_compute( - config, embedding_module, is_pp_last_stage(p2p_communicator.pp_group), tp_group + config, embedding_module, p2p_communicator.is_pp_last_stage, tp_group ) # Finalize model grads (perform full grad all-reduce / reduce-scatter for @@ -1949,16 +1976,22 @@ def get_tensor_shapes( micro_batch_size: int, decoder_seq_length: int, config, - tp_group: torch.distributed.ProcessGroup, - cp_group: torch.distributed.ProcessGroup, + tp_group: Optional[torch.distributed.ProcessGroup] = None, + cp_group: Optional[torch.distributed.ProcessGroup] = None, ): - """ - Determine right tensor sizes (based on position of rank with respect to split rank) and - model size. - """ + """Determine tensor shapes for pipeline communication. + Returns [()] for variable_seq_lengths mode (shapes exchanged dynamically), + or computed shapes for fixed sequence length mode. + """ tensor_shapes = [] - # Use decoder_seq_length if provided, otherwise use seq_length + + if config.variable_seq_lengths: + # Shapes exchanged dynamically during P2P communication + tensor_shapes.append(()) + return tensor_shapes + + # Fixed sequence lengths - compute shape effective_seq_length = decoder_seq_length if decoder_seq_length is not None else seq_length effective_seq_length = effective_seq_length // cp_group.size() @@ -1983,7 +2016,7 @@ def forward_backward_pipelining_without_interleaving( first_val_step: Optional[bool] = None, adjust_tensor_shapes_fn: Optional[Callable] = None, p2p_communicator: Optional[P2PCommunicator] = None, - pg_collection: Optional[ProcessGroupCollection] = None, + pg_collection: Optional[Union[ProcessGroupCollection, ProcessGroupCollectionWrapper]] = None, force_all_reduce: Optional[bool] = False, ): """Run non-interleaved 1F1B schedule, with communication between pipeline @@ -2006,12 +2039,16 @@ def forward_backward_pipelining_without_interleaving( "Non-interleaved pipeline parallelism does not support overlapping p2p communication" ) + tp_group, cp_group, cp_size = None, None, None + if p2p_communicator is None and pg_collection is None: + # Default: single-module with parallel_state groups p2p_communicator = P2PCommunicator( pp_group=parallel_state.get_pipeline_model_parallel_group(), config=config ) tp_group = parallel_state.get_tensor_model_parallel_group() cp_group = parallel_state.get_context_parallel_group() + cp_size = cp_group.size() embd_group = parallel_state.get_embedding_group(check_initialized=False) pos_emb_group = parallel_state.get_position_embedding_group(check_initialized=False) pp_group = parallel_state.get_pipeline_model_parallel_group() @@ -2025,43 +2062,41 @@ def forward_backward_pipelining_without_interleaving( pg_collection.dp_cp = parallel_state.get_data_parallel_group( with_context_parallel=True, partial_data_parallel=False ) + elif p2p_communicator is not None and pg_collection is not None: - model_type = get_model_type(model) - assert model_type != ModelType.encoder_and_decoder, ( - "encoder PP stages not yet supported when passing custom process groups. " - "support coming soon!" - ) - assert hasattr(p2p_communicator, 'config'), "p2p_communicator must have a config" - assert hasattr(pg_collection, 'tp'), "pg_collection must have tp_group" - assert hasattr(pg_collection, 'cp'), "pg_collection must have cp_group" - assert hasattr(pg_collection, 'embd'), ( - "pg_collection must have a embd. In previous version, it is used default " - "`parallel_state.default_embedding_ranks` to create the process group. " - " If you are using the default process group, please use " - " `parallel_state.get_embedding_group()` " - "If you don't need embd_group, you need to explicitly set it to None." - ) - assert hasattr(pg_collection, 'pos_embd'), ( - "pg_collection must have a pos_embd. In previous version, it is used default " - "`parallel_state.default_position_embedding_ranks` to create the process group. " - " If you are using the default process group, please use " - " `parallel_state.get_position_embedding_group()` " - "If you don't need pos_embd_group, you need to explicitly set it to None." - ) - assert hasattr(pg_collection, 'pp'), "pg_collection must have pp_group" - assert hasattr(pg_collection, 'dp_cp'), "pg_collection must have dp_cp_group" - tp_group = pg_collection.tp - cp_group = pg_collection.cp + # Custom process groups provided + + if isinstance(pg_collection, ProcessGroupCollectionWrapper): + # Multi-module: use language model's CP size for loss scaling + if not config.variable_seq_lengths: + raise ValueError( + "config.variable_seq_lengths=True required for multi-module pipelines" + ) + cp_size = pg_collection.get_language_model_cp_size() + # tp_group and cp_group stay None (variable_seq_lengths mode) + + elif isinstance(pg_collection, ProcessGroupCollection): + # Single-module: extract tp/cp groups and cp_size + # Note: finalize_model_grads validates other fields (embd, pos_embd, pp, dp_cp) + assert hasattr(pg_collection, 'tp'), "pg_collection must have tp" + assert hasattr(pg_collection, 'cp'), "pg_collection must have cp" + tp_group = pg_collection.tp + cp_group = pg_collection.cp + cp_size = cp_group.size() + + else: + raise TypeError( + f"pg_collection must be ProcessGroupCollection or ProcessGroupCollectionWrapper, " + f"got {type(pg_collection)}" + ) + else: - raise ValueError( - "Invalid combination of p2p_communicator, pg_collection " - "provide none or provide all the process groups" - ) + raise ValueError("Provide both p2p_communicator and pg_collection, or neither") # Needed only when gradients are finalized in M-Core if config.finalize_model_grads_func is not None and not forward_only: embedding_module = clear_embedding_activation_buffer( - config, model, is_pp_last_stage(p2p_communicator.pp_group) + config, model, p2p_communicator.is_pp_last_stage ) if config.timers is not None: @@ -2090,9 +2125,7 @@ def enable_grad_sync(): disable_grad_sync() # Compute number of warmup microbatches. - num_warmup_microbatches = ( - p2p_communicator.pp_group.size() - p2p_communicator.pp_group.rank() - 1 - ) + num_warmup_microbatches = p2p_communicator.num_warmup_microbatches num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches) num_microbatches_remaining = num_microbatches - num_warmup_microbatches @@ -2108,9 +2141,12 @@ def enable_grad_sync(): if config.num_microbatches_with_partial_activation_checkpoints is not None: max_outstanding_backprops = num_warmup_microbatches + 1 - model_type = get_model_type(model) + # Select backward function based on whether multi-module or single-module + is_multimodule = isinstance(pg_collection, ProcessGroupCollectionWrapper) or isinstance( + p2p_communicator, MultiModulePipelineCommunicator + ) + backward_func = backward_step_multimodule if is_multimodule else backward_step - rank = p2p_communicator.pp_group.rank() recv_tensor_shapes = get_tensor_shapes( seq_length=seq_length, micro_batch_size=micro_batch_size, @@ -2154,7 +2190,7 @@ def enable_grad_sync(): checkpoint_activations_microbatch = None input_tensor = p2p_communicator.recv_forward( - recv_tensor_shapes, is_pp_first_stage(p2p_communicator.pp_group) + recv_tensor_shapes, p2p_communicator.is_pp_first_stage ) output_tensor, num_tokens = forward_step( forward_step_func, @@ -2164,27 +2200,27 @@ def enable_grad_sync(): input_tensor, forward_data_store, config, - cp_group_size=pg_collection.cp.size(), + cp_group_size=cp_size, collect_non_loss_data=collect_non_loss_data, checkpoint_activations_microbatch=checkpoint_activations_microbatch, is_first_microbatch=check_first_val_step(first_val_step, forward_only, i == 0), current_microbatch=i, - is_last_stage=is_pp_last_stage(p2p_communicator.pp_group), + is_last_stage=p2p_communicator.is_pp_last_stage, ) - p2p_communicator.send_forward(output_tensor, is_pp_last_stage(p2p_communicator.pp_group)) + p2p_communicator.send_forward(output_tensor, p2p_communicator.is_pp_last_stage) total_num_tokens += num_tokens if not forward_only: input_tensors.append(input_tensor) output_tensors.append(output_tensor) - deallocate_output_tensor(output_tensor[0], config.deallocate_pipeline_outputs) + deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs) # Before running 1F1B, need to receive first forward tensor. # If all microbatches are run in warmup / cooldown phase, then no need to # receive this tensor here. if num_microbatches_remaining > 0: input_tensor = p2p_communicator.recv_forward( - recv_tensor_shapes, is_pp_first_stage(p2p_communicator.pp_group) + recv_tensor_shapes, p2p_communicator.is_pp_first_stage ) # Run 1F1B in steady state. @@ -2207,34 +2243,34 @@ def enable_grad_sync(): input_tensor, forward_data_store, config, - cp_group_size=pg_collection.cp.size(), + cp_group_size=cp_size, collect_non_loss_data=collect_non_loss_data, checkpoint_activations_microbatch=checkpoint_activations_microbatch, is_first_microbatch=check_first_val_step( first_val_step, forward_only, (i == 0) and (num_warmup_microbatches == 0) ), current_microbatch=i + num_warmup_microbatches, - is_last_stage=is_pp_last_stage(p2p_communicator.pp_group), + is_last_stage=p2p_communicator.is_pp_last_stage, ) total_num_tokens += num_tokens if forward_only: p2p_communicator.send_forward( - output_tensor, is_pp_last_stage(p2p_communicator.pp_group) + output_tensor, p2p_communicator.is_pp_last_stage ) if not last_iteration: input_tensor = p2p_communicator.recv_forward( - recv_tensor_shapes, is_pp_first_stage(p2p_communicator.pp_group) + recv_tensor_shapes, p2p_communicator.is_pp_first_stage ) else: output_tensor_grad = p2p_communicator.send_forward_recv_backward( - output_tensor, send_tensor_shapes, is_pp_last_stage(p2p_communicator.pp_group) + output_tensor, send_tensor_shapes, p2p_communicator.is_pp_last_stage ) # Add input_tensor and output_tensor to end of list. input_tensors.append(input_tensor) output_tensors.append(output_tensor) - deallocate_output_tensor(output_tensor[0], config.deallocate_pipeline_outputs) + deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs) # Pop input_tensor and output_tensor from the start of the list for # the backward pass. @@ -2244,23 +2280,23 @@ def enable_grad_sync(): # Enable grad sync for the last microbatch in the batch if the full # backward pass completes in the 1F1B stage. if num_warmup_microbatches == 0 and last_iteration: - if config.grad_sync_func is None or rank == 0: + if config.grad_sync_func is None or p2p_communicator.is_pp_first_stage: enable_grad_sync() - input_tensor_grad = backward_step( - input_tensor, output_tensor, output_tensor_grad, model_type, config + input_tensor_grad = backward_func( + input_tensor, output_tensor, output_tensor_grad, config ) if last_iteration: input_tensor = None p2p_communicator.send_backward( - input_tensor_grad, is_pp_first_stage(p2p_communicator.pp_group) + input_tensor_grad, p2p_communicator.is_pp_first_stage ) else: input_tensor = p2p_communicator.send_backward_recv_forward( input_tensor_grad, recv_tensor_shapes, - is_pp_first_stage(p2p_communicator.pp_group), + p2p_communicator.is_pp_first_stage, ) # Run cooldown backward passes. @@ -2273,22 +2309,22 @@ def enable_grad_sync(): # pipeline stages do grad reduction during pipeline # bubble. if i == num_warmup_microbatches - 1: - if config.grad_sync_func is None or rank == 0: + if config.grad_sync_func is None or p2p_communicator.is_pp_first_stage: enable_grad_sync() input_tensor = input_tensors.pop(0) output_tensor = output_tensors.pop(0) output_tensor_grad = p2p_communicator.recv_backward( - send_tensor_shapes, is_pp_last_stage(p2p_communicator.pp_group) + send_tensor_shapes, p2p_communicator.is_pp_last_stage ) - input_tensor_grad = backward_step( - input_tensor, output_tensor, output_tensor_grad, model_type, config + input_tensor_grad = backward_func( + input_tensor, output_tensor, output_tensor_grad, config ) p2p_communicator.send_backward( - input_tensor_grad, is_pp_first_stage(p2p_communicator.pp_group) + input_tensor_grad, p2p_communicator.is_pp_first_stage ) # Launch any remaining grad reductions. @@ -2302,7 +2338,7 @@ def enable_grad_sync(): # If defer_embedding_wgrad_compute is enabled we need to do the # weight gradient GEMM's here. finish_embedding_wgrad_compute( - config, embedding_module, is_pp_last_stage(p2p_communicator.pp_group), tp_group + config, embedding_module, p2p_communicator.is_pp_last_stage, tp_group ) # Finalize model grads (perform full grad all-reduce / reduce-scatter for diff --git a/megatron/core/pipeline_parallel/utils.py b/megatron/core/pipeline_parallel/utils.py index 03c5f01f443..0d7c21a6456 100644 --- a/megatron/core/pipeline_parallel/utils.py +++ b/megatron/core/pipeline_parallel/utils.py @@ -3,7 +3,7 @@ import logging from abc import ABC, abstractmethod from contextlib import contextmanager -from typing import Callable, Optional +from typing import Callable, Dict, Optional, Union import torch from torch.autograd import Variable @@ -349,3 +349,87 @@ def get_comm_stream(): """Get the stream for communication""" global _COMM_STREAM return _COMM_STREAM + + +def backward_step_multimodule( + input_tensor: Dict[str, torch.Tensor], + output_tensor: Union[torch.Tensor, Dict[str, torch.Tensor]], + output_tensor_grad: Optional[Dict[str, torch.Tensor]], + config, +) -> Dict[str, torch.Tensor]: + """Backward step for multi-module pipelines. + + In multi-module pipelines, tensors are organized as dictionaries with + module names as keys. Each module's backward pass is performed independently. + + This function should be called explicitly for multi-module pipelines. + For single-module pipelines, use backward_step() instead. + + Args: + input_tensor: Dict mapping module names to input tensors + output_tensor: Dict mapping module names to output tensors, or scalar loss (last stage) + output_tensor_grad: Dict mapping module names to output grads, or None (last stage) + config: Model parallel configuration + + Returns: + Dict mapping module names to input tensor gradients + + Note: + - Assumes each module operates independently (no cross-module gradients in forward) + - Each module should have sequential pipeline stages (no cross-stage skip connections) + - Encoder-decoder models with skip connections (e.g., T5) are not yet supported as LLM. + - Last stage: Scalar loss requires single-module; multi-module should return dict of losses + """ + # Import locally to avoid circular dependency + from megatron.core.pipeline_parallel.schedules import custom_backward + + # Retain gradients on all input tensors + for module_name, tensor in input_tensor.items(): + if isinstance(tensor, list): + tensor = tensor[0] + if tensor is not None: + tensor.retain_grad() + + # Last stage: output_tensor is a scalar loss, wrap in dict for uniform handling + # Assumes last stage only has one module (LLM) + if not isinstance(output_tensor, dict): + all_keys = list(input_tensor.keys()) + main_module_key = all_keys[0] + output_tensor = {main_module_key: output_tensor} + + # Handle output_tensor_grad: None (last stage) or dict (intermediate stages) + if not output_tensor_grad: + # Last stage: no gradient from next stage + output_tensor_grad = {key: None for key in output_tensor.keys()} + + # Apply grad scaling if needed (for last stage only) + for module_name in output_tensor.keys(): + if output_tensor_grad[module_name] is None and config.grad_scale_func is not None: + output_tensor[module_name] = config.grad_scale_func(output_tensor[module_name]) + + # Perform backward pass for each module + for module_name in output_tensor.keys(): + output_tensor_module = output_tensor[module_name] + output_tensor_grad_module = output_tensor_grad[module_name] + + # Skip backward if tensor doesn't require gradients + # (e.g., in VLM models, some batches may not have images) + if output_tensor_module is not None and output_tensor_module.requires_grad: + if config.deallocate_pipeline_outputs: + custom_backward(output_tensor_module, output_tensor_grad_module) + else: + torch.autograd.backward( + output_tensor_module, grad_tensors=output_tensor_grad_module + ) + + # Collect gradients for input tensors + input_tensor_grad = {} + for module_name, tensor in input_tensor.items(): + if isinstance(tensor, list): + tensor = tensor[0] + if tensor is None: + input_tensor_grad[module_name] = None + else: + input_tensor_grad[module_name] = tensor.grad + + return input_tensor_grad From b22f638fa2fa382adf914b6251cc87a9fb2fa152 Mon Sep 17 00:00:00 2001 From: ykarnati Date: Wed, 28 Jan 2026 12:57:34 -0800 Subject: [PATCH 04/30] fix dim mapping in torch cat bridge comm --- megatron/core/pipeline_parallel/bridge_communicator.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/megatron/core/pipeline_parallel/bridge_communicator.py b/megatron/core/pipeline_parallel/bridge_communicator.py index f1e74a2f16d..df206f6f7e6 100644 --- a/megatron/core/pipeline_parallel/bridge_communicator.py +++ b/megatron/core/pipeline_parallel/bridge_communicator.py @@ -494,7 +494,7 @@ def recv_backward(self) -> torch.Tensor: received_gradients_list.append(grad_tensor) # Concatenate received gradients - aggregated_gradient = torch.cat(received_gradients_list, dim=0) + aggregated_gradient = torch.cat(received_gradients_list, dim=self.dim_mapping['b']) logging.debug( f"[Bridge Communicator] [receive_backward] Rank {self.current_rank} " f"agg grad shape {aggregated_gradient.shape} sum {aggregated_gradient.sum()}" @@ -615,7 +615,7 @@ def send_forward_recv_backward( req.wait() # Concatenate received gradients - aggregated_gradient = torch.cat(received_gradients_list, dim=0) + aggregated_gradient = torch.cat(received_gradients_list, dim=self.dim_mapping['b']) logging.debug( f"[Bridge Communicator] [send_forward_recv_backward] Rank {self.current_rank} " f"agg grad shape {aggregated_gradient.shape} sum {aggregated_gradient.sum()}" @@ -737,9 +737,9 @@ def send_backward_recv_forward( req.wait() # Concatenate received activations - aggregated_activation = torch.cat(received_activations_list, dim=0) + aggregated_activation = torch.cat(received_activations_list, dim=self.dim_mapping['b']) logging.debug( - f"[Bridge Communicator] [send_backward_recv_backward] Rank {self.current_rank} " + f"[Bridge Communicator] [send_backward_recv_forward] Rank {self.current_rank} " f"agg act shape {aggregated_activation.shape} sum {aggregated_activation.sum()}" ) From 3badf57bacb2717d549d811e4345081c6edcc552 Mon Sep 17 00:00:00 2001 From: ykarnati Date: Wed, 28 Jan 2026 12:58:08 -0800 Subject: [PATCH 05/30] handle 3d 2d tensor conversion in multimodule comm --- .../multimodule_communicator.py | 98 +++++++++++-------- 1 file changed, 56 insertions(+), 42 deletions(-) diff --git a/megatron/core/pipeline_parallel/multimodule_communicator.py b/megatron/core/pipeline_parallel/multimodule_communicator.py index 1e8da3468e2..daa5d69f935 100644 --- a/megatron/core/pipeline_parallel/multimodule_communicator.py +++ b/megatron/core/pipeline_parallel/multimodule_communicator.py @@ -45,6 +45,32 @@ class RankModuleInfo: is_terminal_stage: Optional[bool] = True +def _ensure_3d_tensor(tensor): + """Ensure tensor is 3D for P2P/bridge communication. + + P2P and bridge communicators expect 3D tensors. + Handles both single tensors and lists of tensors (for VPP). + """ + if isinstance(tensor, list): + return [_ensure_3d_tensor(t) for t in tensor] + if isinstance(tensor, torch.Tensor) and tensor.ndim == 2: + return tensor.unsqueeze(-1) + return tensor + + +def _restore_tensor_shape(tensor): + """Restore original tensor shape after P2P/bridge communication. + + Remove the extra dimension added by _ensure_3d_tensor if it was singleton. + Handles both single tensors and lists of tensors (for VPP). + """ + if isinstance(tensor, list): + return [_restore_tensor_shape(t) for t in tensor] + if isinstance(tensor, torch.Tensor) and tensor.ndim == 3 and tensor.shape[-1] == 1: + return tensor.squeeze(-1) + return tensor + + class MultiModulePipelineCommunicator: """Communicator for a multi-module pipeline.""" @@ -266,12 +292,14 @@ def recv_forward( # If first stage, and has incoming modules, receive forward activation # from incoming modules. for bridge_comm in rank_module_info.bridge_comms_as_dest_module: - input_dict[bridge_comm.src_module_name] = bridge_comm.recv_forward() + received_tensor = bridge_comm.recv_forward() + input_dict[bridge_comm.src_module_name] = _restore_tensor_shape(received_tensor) else: # If not first stage, receive forward activation tensor from P2P communicator. - input_dict[module_name] = rank_module_info.p2p_communicator.recv_forward( + received_tensor = rank_module_info.p2p_communicator.recv_forward( tensor_shapes=tensor_shape, is_first_stage=False ) + input_dict[module_name] = _restore_tensor_shape(received_tensor) return input_dict def send_forward(self, output_dict: Dict[str, torch.Tensor], is_last_stage: bool = False): @@ -280,20 +308,18 @@ def send_forward(self, output_dict: Dict[str, torch.Tensor], is_last_stage: bool Args: output_dict: A dictionary mapping module names to tensors. """ - logging.debug( - f"[Rank {dist.get_rank()} ][MultiModulePipelineCommunicator] " - f"[send_forward] output_dict keys: {output_dict.keys()}, is_last_stage: {is_last_stage}" - ) for module_name, rank_module_info in self.rank_module_map.items(): if rank_module_info.pp_rank == rank_module_info.pp_size - 1: # If last stage, and has outgoing modules, send forward activation # by using bridge communicator. for bridge_comm in rank_module_info.bridge_comms_as_src_module: - bridge_comm.send_forward(output_dict[module_name]) + tensor_to_send = _ensure_3d_tensor(output_dict[module_name]) + bridge_comm.send_forward(tensor_to_send) else: # If not last stage, send forward activation by using P2P communicator. + tensor_to_send = _ensure_3d_tensor(output_dict[module_name]) rank_module_info.p2p_communicator.send_forward( - output_dict[module_name], is_last_stage=False + tensor_to_send, is_last_stage=False ) def send_forward_recv_backward( @@ -311,28 +337,23 @@ def send_forward_recv_backward( Returns: A dictionary mapping module names to tensors. """ - logging.debug( - f"[Rank {dist.get_rank()} ][MultiModulePipelineCommunicator] " - f"[send_forward_recv_backward] output_dict keys: {output_dict.keys()}, " - f"tensor_shape: {tensor_shape}, is_last_stage: {is_last_stage}" - ) grad_dict = {} for module_name, rank_module_info in self.rank_module_map.items(): if rank_module_info.pp_rank == rank_module_info.pp_size - 1: # If last stage, and has outgoing modules, send forward activation and # receive backward gradient by using bridge communicator. for bridge_comm in rank_module_info.bridge_comms_as_src_module: - grad_dict[bridge_comm.src_module_name] = bridge_comm.send_forward_recv_backward( - output_dict[module_name] - ) + tensor_to_send = _ensure_3d_tensor(output_dict[module_name]) + grad = bridge_comm.send_forward_recv_backward(tensor_to_send) + grad_dict[bridge_comm.src_module_name] = _restore_tensor_shape(grad) else: # If not last stage, send forward activation and receive backward gradient # by using P2P communicator. - grad_dict[module_name] = ( - rank_module_info.p2p_communicator.send_forward_recv_backward( - output_dict[module_name], tensor_shapes=tensor_shape, is_last_stage=False - ) + tensor_to_send = _ensure_3d_tensor(output_dict[module_name]) + grad = rank_module_info.p2p_communicator.send_forward_recv_backward( + tensor_to_send, tensor_shapes=tensor_shape, is_last_stage=False ) + grad_dict[module_name] = _restore_tensor_shape(grad) return grad_dict def send_backward_recv_forward( @@ -350,30 +371,23 @@ def send_backward_recv_forward( Returns: A dictionary mapping module names to tensors. """ - logging.debug( - f"[Rank {dist.get_rank()} ][MultiModulePipelineCommunicator] " - f"[send_backward_recv_forward] grad_dict keys: {grad_dict.keys()}, " - f"tensor_shape: {tensor_shape}, is_first_stage: {is_first_stage}" - ) input_dict = {} for module_name, rank_module_info in self.rank_module_map.items(): if rank_module_info.pp_rank == 0: for bridge_comm in rank_module_info.bridge_comms_as_dest_module: # If first stage, and has incoming modules, send backward gradient and # receive forward activation by using bridge communicator. - input_dict[bridge_comm.src_module_name] = ( - bridge_comm.send_backward_recv_forward( - grad_dict[bridge_comm.src_module_name] - ) - ) + grad_to_send = _ensure_3d_tensor(grad_dict[bridge_comm.src_module_name]) + received_tensor = bridge_comm.send_backward_recv_forward(grad_to_send) + input_dict[bridge_comm.src_module_name] = _restore_tensor_shape(received_tensor) else: # If not first stage, send backward gradient and receive forward activation # by using P2P communicator. - input_dict[module_name] = ( - rank_module_info.p2p_communicator.send_backward_recv_forward( - grad_dict[module_name], tensor_shapes=tensor_shape, is_first_stage=False - ) + grad_to_send = _ensure_3d_tensor(grad_dict[module_name]) + received_tensor = rank_module_info.p2p_communicator.send_backward_recv_forward( + grad_to_send, tensor_shapes=tensor_shape, is_first_stage=False ) + input_dict[module_name] = _restore_tensor_shape(received_tensor) return input_dict def recv_backward( @@ -397,12 +411,14 @@ def recv_backward( # If last stage, and has incoming modules, receive backward gradient # by using bridge communicator. for bridge_comm in rank_module_info.bridge_comms_as_src_module: - grad_dict[bridge_comm.src_module_name] = bridge_comm.recv_backward() + grad = bridge_comm.recv_backward() + grad_dict[bridge_comm.src_module_name] = _restore_tensor_shape(grad) else: # If not last stage, receive backward gradient by using P2P communicator. - grad_dict[module_name] = rank_module_info.p2p_communicator.recv_backward( + grad = rank_module_info.p2p_communicator.recv_backward( tensor_shapes=tensor_shape, is_last_stage=False ) + grad_dict[module_name] = _restore_tensor_shape(grad) return grad_dict def send_backward(self, grad_dict: Dict[str, torch.Tensor], is_first_stage: bool = False): @@ -411,20 +427,18 @@ def send_backward(self, grad_dict: Dict[str, torch.Tensor], is_first_stage: bool Args: grad_dict: A dictionary mapping module names to tensors. """ - logging.debug( - f"[Rank {dist.get_rank()} ][MultiModulePipelineCommunicator] " - f"[send_backward] grad_dict keys: {grad_dict.keys()}, is_first_stage: {is_first_stage}" - ) for module_name, rank_module_info in self.rank_module_map.items(): if rank_module_info.pp_rank == 0: # If first stage, and has incoming modules, send backward activation # by using bridge communicator. for bridge_comm in rank_module_info.bridge_comms_as_dest_module: - bridge_comm.send_backward(grad_dict[bridge_comm.src_module_name]) + grad_to_send = _ensure_3d_tensor(grad_dict[bridge_comm.src_module_name]) + bridge_comm.send_backward(grad_to_send) else: # If not first stage, send backward activation by using P2P communicator. + grad_to_send = _ensure_3d_tensor(grad_dict[module_name]) rank_module_info.p2p_communicator.send_backward( - grad_dict[module_name], is_first_stage=False + grad_to_send, is_first_stage=False ) @staticmethod From 20d03f5dbd5b9d5aa3a581404fd3af633a6affd7 Mon Sep 17 00:00:00 2001 From: ykarnati Date: Wed, 28 Jan 2026 13:07:13 -0800 Subject: [PATCH 06/30] add unit tests for multimodule pipeline schedules --- .../test_multimodule_schedules.py | 492 ++++++++++++++++++ 1 file changed, 492 insertions(+) create mode 100644 tests/unit_tests/pipeline_parallel/test_multimodule_schedules.py diff --git a/tests/unit_tests/pipeline_parallel/test_multimodule_schedules.py b/tests/unit_tests/pipeline_parallel/test_multimodule_schedules.py new file mode 100644 index 00000000000..9df3037da70 --- /dev/null +++ b/tests/unit_tests/pipeline_parallel/test_multimodule_schedules.py @@ -0,0 +1,492 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +"""Tests for multimodule pipeline schedules with heterogeneous parallelism.""" + +from contextlib import contextmanager +from typing import Dict, Optional + +import pytest +import torch +import torch.distributed as dist +from packaging import version + +import megatron.core.pipeline_parallel.schedules as schedule +from megatron.core import ModelParallelConfig +from megatron.core.distributed import DistributedDataParallel, DistributedDataParallelConfig +from megatron.core.distributed.finalize_model_grads import finalize_model_grads +from megatron.core.hyper_comm_grid import HyperCommGrid +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from megatron.core.pipeline_parallel.multimodule_communicator import ( + MultiModulePipelineCommunicator, +) +from megatron.core.pipeline_parallel.utils import is_pp_first_stage, is_pp_last_stage +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from megatron.core.transformer.transformer_block import TransformerBlock +from megatron.core.transformer.transformer_config import TransformerConfig +from tests.unit_tests.test_utilities import Utils + + +# ============================================================================ +# Helper Functions +# ============================================================================ + + +def create_hypercomm_grid(offset=0, tp=1, pp=1, dp=1): + """Create a HyperCommGrid with specified parallelism.""" + grid = HyperCommGrid( + shape=[tp, 1, pp, dp, 1], # [tp, cp, pp, dp, ep] + dim_names=["tp", "cp", "pp", "dp", "ep"], + rank_offset=offset, + backend="nccl", + ) + grid.create_pg(["tp"]) + grid.create_pg(["cp"]) + grid.create_pg(["pp"]) + grid.create_pg(["dp"]) + grid.create_pg(["dp", "cp"]) + grid.create_pg(["ep"]) + return grid + + +def get_pg_collection(grid): + """Get ProcessGroupCollection from grid.""" + pg_collection = ProcessGroupCollection() + pg_collection.tp = grid.get_pg("tp") + pg_collection.cp = grid.get_pg("cp") + pg_collection.pp = grid.get_pg("pp") + pg_collection.ep = grid.get_pg("ep") + pg_collection.dp = grid.get_pg("dp") + pg_collection.dp_cp = grid.get_pg(["dp", "cp"]) + return pg_collection + + +def add_embedding_groups(pg_collection): + """Add embedding groups to process group collection.""" + if not pg_collection.pp: + return pg_collection + + pp_ranks = sorted(dist.get_process_group_ranks(pg_collection.pp)) + pos_embd_ranks = [pp_ranks[0]] + embd_ranks = [pp_ranks[0]] + if pp_ranks[-1] != pp_ranks[0]: + embd_ranks.append(pp_ranks[-1]) + + pos_embd_pg = dist.new_group(ranks=pos_embd_ranks) + embd_pg = dist.new_group(ranks=embd_ranks) + + # Always set pos_embd and embd (to group or None) + pg_collection.pos_embd = pos_embd_pg if is_pp_first_stage(pg_collection.pp) else None + pg_collection.embd = ( + embd_pg + if (is_pp_last_stage(pg_collection.pp) or is_pp_first_stage(pg_collection.pp)) + else None + ) + + return pg_collection + + +def create_transformer_block(hidden_size, pg_collection, dtype=torch.bfloat16): + """Create a transformer block for testing.""" + torch.manual_seed(12345) + model_parallel_cuda_manual_seed( + 123, + tp_rank=pg_collection.tp.rank(), + ep_rank=pg_collection.ep.rank() if hasattr(pg_collection, 'ep') else 0, + etp_rank=dist.get_rank(), + ) + + config = TransformerConfig( + num_layers=1, + hidden_size=hidden_size, + num_attention_heads=8, + use_cpu_initialization=True, + attention_dropout=0.0, + hidden_dropout=0.0, + bf16=(dtype == torch.bfloat16), + ) + + block = TransformerBlock( + config, + get_gpt_layer_with_transformer_engine_spec(), + pg_collection=pg_collection, + ).cuda().to(dtype) + + with torch.no_grad(): + for mod in block.modules(): + if hasattr(mod, "bias") and mod.bias is not None: + mod.bias.zero_() + + ddp_config = DistributedDataParallelConfig(overlap_grad_reduce=True, bucket_size=10000) + block = DistributedDataParallel( + config=block.config, ddp_config=ddp_config, module=block, pg_collection=pg_collection + ) + block.pre_process = False + block.post_process = False + block.share_embeddings_and_output_weights = False + return block + + +def create_module_with_grid(tp, pp, dp, grid_offset, hidden_size): + """Create a module (transformer block) with its grid.""" + rank = dist.get_rank() + grid = create_hypercomm_grid(offset=grid_offset, tp=tp, pp=pp, dp=dp) + + if grid.rank_offset <= rank < grid.rank_offset + grid.size: + pg_collection = add_embedding_groups(get_pg_collection(grid)) + module = create_transformer_block(hidden_size, pg_collection) + else: + module = None + + return module, grid + + +# ============================================================================ +# Model Wrapper +# ============================================================================ + + +class MultiModuleModel(torch.nn.Module): + """Wrapper for testing multimodule schedules with multiple encoders + LLM.""" + + def __init__(self, encoder_configs, llm_config, hidden_size): + """ + Args: + encoder_configs: List of dicts with keys: tp, pp, dp, grid_offset, name + llm_config: Dict with keys: tp, pp, dp, grid_offset + hidden_size: Hidden dimension size + """ + super().__init__() + self.hidden_size = hidden_size + self.rank = dist.get_rank() + + # Create encoders + self.encoders = {} + self.encoder_grids = {} + for enc_cfg in encoder_configs: + name = enc_cfg['name'] + module, grid = create_module_with_grid( + enc_cfg['tp'], enc_cfg['pp'], enc_cfg['dp'], + enc_cfg['grid_offset'], hidden_size + ) + self.encoders[name] = module + self.encoder_grids[name] = grid + + # Create LLM + self.llm, self.llm_grid = create_module_with_grid( + llm_config['tp'], llm_config['pp'], llm_config['dp'], + llm_config['grid_offset'], hidden_size + ) + + # Track all modules for gradient sync + self.modules_and_grids = [] + for name, module in self.encoders.items(): + self.modules_and_grids.append((module, self.encoder_grids[name])) + self.modules_and_grids.append((self.llm, self.llm_grid)) + + # Input tensors for pipeline stages + self.input_tensors = {name: None for name in self.encoders.keys()} + self.input_tensors['llm'] = None + + def is_rank_in_grid(self, grid): + """Check if current rank is in grid.""" + return grid.rank_offset <= self.rank < grid.rank_offset + grid.size + + @contextmanager + def no_sync(self): + """No-sync context for all active modules.""" + contexts = [] + for module, grid in self.modules_and_grids: + if module is not None and self.is_rank_in_grid(grid): + contexts.append(module.no_sync()) + + for ctx in contexts: + ctx.__enter__() + try: + yield + finally: + for ctx in reversed(contexts): + ctx.__exit__(None, None, None) + + @property + def ddp_config(self): + """Get DDP config from first active module.""" + for module, grid in self.modules_and_grids: + if module is not None and self.is_rank_in_grid(grid): + return module.ddp_config + raise AttributeError(f"No active modules on rank {self.rank}") + + def finalize_model_grads(self, *args, **kwargs): + """Finalize gradients for all active modules.""" + for module, grid in self.modules_and_grids: + if module is not None and self.is_rank_in_grid(grid): + pg_collection = add_embedding_groups(get_pg_collection(grid)) + finalize_model_grads([module], num_tokens=None, pg_collection=pg_collection) + + def set_input_tensor(self, input_tensor): + """Set input tensors from previous pipeline stage.""" + if not input_tensor or not input_tensor[0]: + return + + tensor_dict = input_tensor[0] + + # Set encoder inputs + for name in self.encoders.keys(): + if name in tensor_dict: + self.input_tensors[name] = ( + tensor_dict[name][0] + if isinstance(tensor_dict[name], list) + else tensor_dict[name] + ) + + # Set LLM input (from either encoder outputs or previous LLM stage) + # Only do this if we're on the LLM grid + if self.is_rank_in_grid(self.llm_grid): + if 'llm' in tensor_dict: + self.input_tensors['llm'] = ( + tensor_dict['llm'][0] + if isinstance(tensor_dict['llm'], list) + else tensor_dict['llm'] + ) + elif len(self.encoders) > 0: + # Concatenate encoder outputs for LLM input (received via bridge) + encoder_outputs = [] + for name in self.encoders.keys(): + if name in tensor_dict: + tensor = tensor_dict[name] + # Extract tensor from list if needed (P2P sends as list) + if isinstance(tensor, list): + tensor = tensor[0] + encoder_outputs.append(tensor) + if encoder_outputs: + self.input_tensors['llm'] = ( + torch.cat(encoder_outputs, dim=0) + if len(encoder_outputs) > 1 + else encoder_outputs[0] + ) + + def forward(self, hidden_states): + """Forward pass through active modules.""" + output_dict = {} + + # Forward through encoders + for name, encoder in self.encoders.items(): + if encoder is not None and self.is_rank_in_grid(self.encoder_grids[name]): + pp_group = self.encoder_grids[name].get_pg("pp") + input_tensor = ( + hidden_states if is_pp_first_stage(pp_group) + else self.input_tensors[name] + ) + output_dict[name] = encoder(input_tensor, attention_mask=None) + + # Forward through LLM + if self.llm is not None and self.is_rank_in_grid(self.llm_grid): + output_dict['llm'] = self.llm(self.input_tensors['llm'], attention_mask=None) + + return output_dict + + +# ============================================================================ +# Data Iterator +# ============================================================================ + + +class DataIterator: + """Simple data iterator for testing.""" + + def __init__(self, hidden_size, seq_length, micro_batch_size): + self.hidden_size = hidden_size + self.seq_length = seq_length + self.micro_batch_size = micro_batch_size + + def __iter__(self): + return self + + def __next__(self): + return torch.randn( + self.seq_length, self.micro_batch_size, self.hidden_size, + device='cuda', dtype=torch.bfloat16 + ) + + +# ============================================================================ +# Test Runner +# ============================================================================ + + +def run_multimodule_schedule_test( + encoder_configs, llm_config, hidden_size, seq_length, + micro_batch_size, num_microbatches +): + """Run multimodule schedule test with given configuration. + + Args: + encoder_configs: List of encoder configs + llm_config: LLM config dict + hidden_size: Hidden dimension + seq_length: Sequence length + micro_batch_size: Micro batch size + num_microbatches: Number of microbatches + """ + # Create model + model = MultiModuleModel(encoder_configs, llm_config, hidden_size) + model.model_type = 'unit-test' + + # Build module_to_grid_map and topology + module_to_grid_map = {name: grid for name, grid in model.encoder_grids.items()} + module_to_grid_map['llm'] = model.llm_grid + + topology = {name: ['llm'] for name in model.encoders.keys()} + topology['llm'] = [] + + # Configure + config = ModelParallelConfig(pipeline_dtype=torch.bfloat16) + config.variable_seq_lengths = True + config.calculate_per_token_loss = False + config.fine_grained_activation_offloading = False + config.qk_layernorm = False + config.sequence_parallel = False + config.moe_router_enable_expert_bias = False + config.moe_router_load_balancing_type = "aux_loss" + config.no_sync_func = model.no_sync + config.finalize_model_grads_func = model.finalize_model_grads + config.grad_scale_func = lambda loss: ( + torch.tensor(loss, dtype=torch.float32, device='cuda', requires_grad=True) + if isinstance(loss, (int, float)) else loss + ) + config.hidden_size = hidden_size + model.config = config + + # Create communicator + communicator = MultiModulePipelineCommunicator( + module_to_grid_map, topology, config, dim_mapping={'s': 0, 'h': 2, 'b': 1} + ) + + # Create data iterator (only on first encoder's first stage) + data_iterator = None + first_encoder_name = encoder_configs[0]['name'] + first_encoder_grid = model.encoder_grids[first_encoder_name] + if model.is_rank_in_grid(first_encoder_grid): + if is_pp_first_stage(first_encoder_grid.get_pg("pp")): + data_iterator = DataIterator(hidden_size, seq_length, micro_batch_size) + + # Get process group collection for current rank + rank = dist.get_rank() + pg_collection = None + for name, grid in model.encoder_grids.items(): + if grid.rank_offset <= rank < grid.rank_offset + grid.size: + pg_collection = add_embedding_groups(get_pg_collection(grid)) + break + if pg_collection is None and model.llm_grid.rank_offset <= rank < model.llm_grid.rank_offset + model.llm_grid.size: + pg_collection = add_embedding_groups(get_pg_collection(model.llm_grid)) + + # Define step function + def step_func(data_iterator, model): + def loss_func(output_tensor_dict: Dict[str, torch.Tensor]): + assert 'llm' in output_tensor_dict, f"Expected 'llm' in output" + loss = output_tensor_dict['llm'].sum() + return loss, {'loss_reduced': loss} + + input_tensor = next(data_iterator) if data_iterator is not None else None + model_output = model(input_tensor) + return model_output, loss_func + + # Run schedule + losses = schedule.forward_backward_pipelining_without_interleaving( + forward_step_func=step_func, + data_iterator=data_iterator, + model=[model], + num_microbatches=num_microbatches, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + forward_only=False, + p2p_communicator=communicator, + pg_collection=pg_collection, + ) + + # Verify results on last LLM stage + if model.is_rank_in_grid(model.llm_grid): + if is_pp_last_stage(model.llm_grid.get_pg("pp")): + assert len(losses) > 0, "Expected losses on last LLM stage" + + return losses + + +# ============================================================================ +# Tests +# ============================================================================ + + +@pytest.mark.skipif( + version.parse(torch.__version__) < version.parse('2.3.0'), + reason="Device mesh requires PyTorch 2.3+", +) +class TestMultimoduleSchedules: + """Test multimodule pipeline schedules.""" + + @classmethod + def setup_class(cls): + Utils.initialize_distributed() + cls.world_size = dist.get_world_size() + + @classmethod + def teardown_class(cls): + Utils.destroy_model_parallel() + + def test_single_encoder_2gpu(self): + """Test single encoder + LLM on 2 GPUs (no PP).""" + if self.world_size != 2: + pytest.skip(f"Requires 2 GPUs, got {self.world_size}") + + encoder_configs = [{'name': 'encoder', 'tp': 1, 'pp': 1, 'dp': 1, 'grid_offset': 0}] + llm_config = {'tp': 1, 'pp': 1, 'dp': 1, 'grid_offset': 1} + + run_multimodule_schedule_test( + encoder_configs, llm_config, + hidden_size=512, seq_length=64, micro_batch_size=2, num_microbatches=4 + ) + + def test_dual_encoder_2gpu(self): + """Test dual encoder + LLM on 2 GPUs (both encoders on rank 0).""" + if self.world_size != 2: + pytest.skip(f"Requires 2 GPUs, got {self.world_size}") + + encoder_configs = [ + {'name': 'encoder_1', 'tp': 1, 'pp': 1, 'dp': 1, 'grid_offset': 0}, + {'name': 'encoder_2', 'tp': 1, 'pp': 1, 'dp': 1, 'grid_offset': 0}, + ] + llm_config = {'tp': 1, 'pp': 1, 'dp': 1, 'grid_offset': 1} + + run_multimodule_schedule_test( + encoder_configs, llm_config, + hidden_size=512, seq_length=64, micro_batch_size=2, num_microbatches=4 + ) + + def test_single_encoder_8gpu(self): + """Test single encoder + LLM on 8 GPUs (TP=2, PP=2 each).""" + if self.world_size != 8: + pytest.skip(f"Requires 8 GPUs, got {self.world_size}") + + encoder_configs = [{'name': 'encoder', 'tp': 2, 'pp': 2, 'dp': 1, 'grid_offset': 0}] + llm_config = {'tp': 2, 'pp': 2, 'dp': 1, 'grid_offset': 4} + + run_multimodule_schedule_test( + encoder_configs, llm_config, + hidden_size=1024, seq_length=512, micro_batch_size=4, num_microbatches=16 + ) + + def test_dual_encoder_8gpu(self): + """Test dual encoder + LLM on 8 GPUs (TP=2, PP=2 for each).""" + if self.world_size != 8: + pytest.skip(f"Requires 8 GPUs, got {self.world_size}") + + encoder_configs = [ + {'name': 'encoder_1', 'tp': 2, 'pp': 2, 'dp': 1, 'grid_offset': 0}, + {'name': 'encoder_2', 'tp': 2, 'pp': 2, 'dp': 1, 'grid_offset': 0}, + ] + llm_config = {'tp': 2, 'pp': 2, 'dp': 1, 'grid_offset': 4} + + run_multimodule_schedule_test( + encoder_configs, llm_config, + hidden_size=1024, seq_length=512, micro_batch_size=4, num_microbatches=16 + ) From a6606d89270b9fb8458969b7ddf2c8df7ea4b4ca Mon Sep 17 00:00:00 2001 From: ykarnati Date: Wed, 28 Jan 2026 14:48:30 -0800 Subject: [PATCH 07/30] refactor multimodule pg collection and backward step - Rename ProcessGroupCollectionWrapper to MultiModuleProcessGroupCollection - Rename language_model field to language_model_module_name for clarity - Add language_model_module_name param to backward_step_multimodule - Use functools.partial to bind param, keeping signature consistent - Add type hints to _ensure_3d_tensor and _restore_tensor_shape - Move is_multimodule check earlier for validation and backward selection --- .../multimodule_communicator.py | 24 ++++++++- megatron/core/pipeline_parallel/schedules.py | 24 ++++++--- megatron/core/pipeline_parallel/utils.py | 13 ++--- megatron/core/process_groups_config.py | 52 +++++++++---------- 4 files changed, 71 insertions(+), 42 deletions(-) diff --git a/megatron/core/pipeline_parallel/multimodule_communicator.py b/megatron/core/pipeline_parallel/multimodule_communicator.py index daa5d69f935..7f79bc59de4 100644 --- a/megatron/core/pipeline_parallel/multimodule_communicator.py +++ b/megatron/core/pipeline_parallel/multimodule_communicator.py @@ -45,12 +45,22 @@ class RankModuleInfo: is_terminal_stage: Optional[bool] = True -def _ensure_3d_tensor(tensor): +def _ensure_3d_tensor( + tensor: Union[torch.Tensor, List[torch.Tensor], None] +) -> Union[torch.Tensor, List[torch.Tensor], None]: """Ensure tensor is 3D for P2P/bridge communication. P2P and bridge communicators expect 3D tensors. Handles both single tensors and lists of tensors (for VPP). + + Args: + tensor: Input tensor (2D or 3D), list of tensors, or None. + + Returns: + 3D tensor (with singleton last dim if input was 2D), list of 3D tensors, or None. """ + if tensor is None: + return None if isinstance(tensor, list): return [_ensure_3d_tensor(t) for t in tensor] if isinstance(tensor, torch.Tensor) and tensor.ndim == 2: @@ -58,12 +68,22 @@ def _ensure_3d_tensor(tensor): return tensor -def _restore_tensor_shape(tensor): +def _restore_tensor_shape( + tensor: Union[torch.Tensor, List[torch.Tensor], None] +) -> Union[torch.Tensor, List[torch.Tensor], None]: """Restore original tensor shape after P2P/bridge communication. Remove the extra dimension added by _ensure_3d_tensor if it was singleton. Handles both single tensors and lists of tensors (for VPP). + + Args: + tensor: Input tensor (3D with singleton last dim), list of tensors, or None. + + Returns: + 2D tensor (if last dim was singleton), list of tensors, or None. """ + if tensor is None: + return None if isinstance(tensor, list): return [_restore_tensor_shape(t) for t in tensor] if isinstance(tensor, torch.Tensor) and tensor.ndim == 3 and tensor.shape[-1] == 1: diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index 0868492f6eb..fc9ed2f9db4 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -25,7 +25,7 @@ ) from megatron.core.process_groups_config import ( ProcessGroupCollection, - ProcessGroupCollectionWrapper, + MultiModuleProcessGroupCollection, ) from megatron.core.transformer.cuda_graphs import create_cudagraphs from megatron.core.transformer.enums import CudaGraphScope @@ -2016,7 +2016,7 @@ def forward_backward_pipelining_without_interleaving( first_val_step: Optional[bool] = None, adjust_tensor_shapes_fn: Optional[Callable] = None, p2p_communicator: Optional[P2PCommunicator] = None, - pg_collection: Optional[Union[ProcessGroupCollection, ProcessGroupCollectionWrapper]] = None, + pg_collection: Optional[Union[ProcessGroupCollection, MultiModuleProcessGroupCollection]] = None, force_all_reduce: Optional[bool] = False, ): """Run non-interleaved 1F1B schedule, with communication between pipeline @@ -2041,6 +2041,11 @@ def forward_backward_pipelining_without_interleaving( tp_group, cp_group, cp_size = None, None, None + # Determine if this is a multi-module pipeline (used for validation and backward function selection) + is_multimodule = isinstance(pg_collection, MultiModuleProcessGroupCollection) or isinstance( + p2p_communicator, MultiModulePipelineCommunicator + ) + if p2p_communicator is None and pg_collection is None: # Default: single-module with parallel_state groups p2p_communicator = P2PCommunicator( @@ -2066,7 +2071,7 @@ def forward_backward_pipelining_without_interleaving( elif p2p_communicator is not None and pg_collection is not None: # Custom process groups provided - if isinstance(pg_collection, ProcessGroupCollectionWrapper): + if is_multimodule: # Multi-module: use language model's CP size for loss scaling if not config.variable_seq_lengths: raise ValueError( @@ -2086,7 +2091,7 @@ def forward_backward_pipelining_without_interleaving( else: raise TypeError( - f"pg_collection must be ProcessGroupCollection or ProcessGroupCollectionWrapper, " + f"pg_collection must be ProcessGroupCollection or MultiModuleProcessGroupCollection, " f"got {type(pg_collection)}" ) @@ -2142,10 +2147,13 @@ def enable_grad_sync(): max_outstanding_backprops = num_warmup_microbatches + 1 # Select backward function based on whether multi-module or single-module - is_multimodule = isinstance(pg_collection, ProcessGroupCollectionWrapper) or isinstance( - p2p_communicator, MultiModulePipelineCommunicator - ) - backward_func = backward_step_multimodule if is_multimodule else backward_step + if is_multimodule: + backward_func = partial( + backward_step_multimodule, + language_model_module_name=pg_collection.language_model_module_name, + ) + else: + backward_func = backward_step recv_tensor_shapes = get_tensor_shapes( seq_length=seq_length, diff --git a/megatron/core/pipeline_parallel/utils.py b/megatron/core/pipeline_parallel/utils.py index 0d7c21a6456..7930f315386 100644 --- a/megatron/core/pipeline_parallel/utils.py +++ b/megatron/core/pipeline_parallel/utils.py @@ -356,6 +356,7 @@ def backward_step_multimodule( output_tensor: Union[torch.Tensor, Dict[str, torch.Tensor]], output_tensor_grad: Optional[Dict[str, torch.Tensor]], config, + language_model_module_name: str, ) -> Dict[str, torch.Tensor]: """Backward step for multi-module pipelines. @@ -370,6 +371,8 @@ def backward_step_multimodule( output_tensor: Dict mapping module names to output tensors, or scalar loss (last stage) output_tensor_grad: Dict mapping module names to output grads, or None (last stage) config: Model parallel configuration + language_model_module_name: Name of the language model module (e.g., 'llm'). + Used to associate scalar loss with the correct module at the terminal stage. Returns: Dict mapping module names to input tensor gradients @@ -378,7 +381,7 @@ def backward_step_multimodule( - Assumes each module operates independently (no cross-module gradients in forward) - Each module should have sequential pipeline stages (no cross-stage skip connections) - Encoder-decoder models with skip connections (e.g., T5) are not yet supported as LLM. - - Last stage: Scalar loss requires single-module; multi-module should return dict of losses + - Last stage: Scalar loss is associated with language_model_module_name. """ # Import locally to avoid circular dependency from megatron.core.pipeline_parallel.schedules import custom_backward @@ -390,12 +393,10 @@ def backward_step_multimodule( if tensor is not None: tensor.retain_grad() - # Last stage: output_tensor is a scalar loss, wrap in dict for uniform handling - # Assumes last stage only has one module (LLM) + # Last stage: output_tensor is a scalar loss from the language model. + # Associate it with the language_model_module_name. if not isinstance(output_tensor, dict): - all_keys = list(input_tensor.keys()) - main_module_key = all_keys[0] - output_tensor = {main_module_key: output_tensor} + output_tensor = {language_model_module_name: output_tensor} # Handle output_tensor_grad: None (last stage) or dict (intermediate stages) if not output_tensor_grad: diff --git a/megatron/core/process_groups_config.py b/megatron/core/process_groups_config.py index 9fa0e080ccc..783ec5c8b96 100644 --- a/megatron/core/process_groups_config.py +++ b/megatron/core/process_groups_config.py @@ -572,58 +572,58 @@ def setup_process_groups_for_ddp( @dataclass -class ProcessGroupCollectionWrapper: - """Wrapper for multiple process group collections in multi-module pipelines. +class MultiModuleProcessGroupCollection: + """Process group collection for multi-module pipelines. Used when a rank participates in multiple modules (e.g., colocated encoder + LLM). - The language_model key identifies which module is the language model (used for - CP size extraction and other LLM-specific operations). + The language_model_module_name identifies which module is the language model (used for + CP size extraction, loss computation, and other LLM-specific operations). Attributes: module_collections: Dict mapping module names to ProcessGroupCollection objects - language_model: Key identifying the language model module (None if no LLM on this rank) + language_model_module_name: Key identifying the language model module (None if no LLM on this rank) Example: # Colocated rank with encoder and LLM - wrapper = ProcessGroupCollectionWrapper( + pg_collection = MultiModuleProcessGroupCollection( module_collections={ "encoder": encoder_pg, "llm": llm_pg }, - language_model="llm" + language_model_module_name="llm" ) # Rank with dual encoders (no LLM) - wrapper = ProcessGroupCollectionWrapper( + pg_collection = MultiModuleProcessGroupCollection( module_collections={ "encoder_1": encoder_1_pg, "encoder_2": encoder_2_pg }, - language_model=None + language_model_module_name=None ) # Single module (can also use ProcessGroupCollection directly) - wrapper = ProcessGroupCollectionWrapper( + pg_collection = MultiModuleProcessGroupCollection( module_collections={"llm": llm_pg}, - language_model="llm" + language_model_module_name="llm" ) # Usage - cp_size = wrapper.get_language_model_cp_size() - encoder_pg = wrapper["encoder_1"] # Dict-like access - has_llm = wrapper.has_language_model() + cp_size = pg_collection.get_language_model_cp_size() + encoder_pg = pg_collection["encoder_1"] # Dict-like access + has_llm = pg_collection.has_language_model() """ module_collections: Dict[str, ProcessGroupCollection] - language_model: Optional[str] = None + language_model_module_name: Optional[str] = None def __post_init__(self): if not self.module_collections: raise ValueError("module_collections dict cannot be empty") - if self.language_model is not None: - if self.language_model not in self.module_collections: + if self.language_model_module_name is not None: + if self.language_model_module_name not in self.module_collections: raise ValueError( - f"language_model '{self.language_model}' not found in " + f"language_model_module_name '{self.language_model_module_name}' not found in " f"module_collections keys: {list(self.module_collections.keys())}" ) @@ -634,11 +634,11 @@ def get_language_model_collection(self) -> ProcessGroupCollection: ProcessGroupCollection for the language model. Raises: - ValueError: If no language model is specified for this wrapper. + ValueError: If no language model is specified for this collection. """ - if self.language_model is None: - raise ValueError("No language model specified for this wrapper") - return self.module_collections[self.language_model] + if self.language_model_module_name is None: + raise ValueError("No language model specified for this collection") + return self.module_collections[self.language_model_module_name] def get_language_model_cp_size(self) -> int: """Get context parallel size for the language model. @@ -647,7 +647,7 @@ def get_language_model_cp_size(self) -> int: Context parallel size for the language model. Raises: - ValueError: If no language model is specified for this wrapper. + ValueError: If no language model is specified for this collection. """ return self.get_language_model_collection().cp.size() @@ -657,7 +657,7 @@ def has_language_model(self) -> bool: Returns: True if this rank has a language model, False otherwise. """ - return self.language_model is not None + return self.language_model_module_name is not None def get_module_collection(self, module_name: str) -> ProcessGroupCollection: """Get process group collection for a specific module. @@ -705,5 +705,5 @@ def items(self): def __repr__(self): """Return a concise representation showing modules and their language model status.""" modules_str = ', '.join(self.module_collections.keys()) - lm_str = f", language_model='{self.language_model}'" if self.language_model else "" - return f"ProcessGroupCollectionWrapper(modules=[{modules_str}]{lm_str})" + lm_str = f", language_model_module_name='{self.language_model_module_name}'" if self.language_model_module_name else "" + return f"MultiModuleProcessGroupCollection(modules=[{modules_str}]{lm_str})" From b102eb77eb2fb5e93532dc645788ae67e9dd5d90 Mon Sep 17 00:00:00 2001 From: ykarnati Date: Wed, 28 Jan 2026 15:25:57 -0800 Subject: [PATCH 08/30] rename module_collections to module_pgs for clarity --- megatron/core/process_groups_config.py | 46 +++++++++++--------------- 1 file changed, 20 insertions(+), 26 deletions(-) diff --git a/megatron/core/process_groups_config.py b/megatron/core/process_groups_config.py index 783ec5c8b96..1513184891e 100644 --- a/megatron/core/process_groups_config.py +++ b/megatron/core/process_groups_config.py @@ -580,31 +580,25 @@ class MultiModuleProcessGroupCollection: CP size extraction, loss computation, and other LLM-specific operations). Attributes: - module_collections: Dict mapping module names to ProcessGroupCollection objects + module_pgs: Dict mapping module names to ProcessGroupCollection objects language_model_module_name: Key identifying the language model module (None if no LLM on this rank) Example: # Colocated rank with encoder and LLM pg_collection = MultiModuleProcessGroupCollection( - module_collections={ - "encoder": encoder_pg, - "llm": llm_pg - }, + module_pgs={"encoder": encoder_pg, "llm": llm_pg}, language_model_module_name="llm" ) # Rank with dual encoders (no LLM) pg_collection = MultiModuleProcessGroupCollection( - module_collections={ - "encoder_1": encoder_1_pg, - "encoder_2": encoder_2_pg - }, + module_pgs={"encoder_1": encoder_1_pg, "encoder_2": encoder_2_pg}, language_model_module_name=None ) # Single module (can also use ProcessGroupCollection directly) pg_collection = MultiModuleProcessGroupCollection( - module_collections={"llm": llm_pg}, + module_pgs={"llm": llm_pg}, language_model_module_name="llm" ) @@ -614,17 +608,17 @@ class MultiModuleProcessGroupCollection: has_llm = pg_collection.has_language_model() """ - module_collections: Dict[str, ProcessGroupCollection] + module_pgs: Dict[str, ProcessGroupCollection] language_model_module_name: Optional[str] = None def __post_init__(self): - if not self.module_collections: - raise ValueError("module_collections dict cannot be empty") + if not self.module_pgs: + raise ValueError("module_pgs dict cannot be empty") if self.language_model_module_name is not None: - if self.language_model_module_name not in self.module_collections: + if self.language_model_module_name not in self.module_pgs: raise ValueError( f"language_model_module_name '{self.language_model_module_name}' not found in " - f"module_collections keys: {list(self.module_collections.keys())}" + f"module_pgs keys: {list(self.module_pgs.keys())}" ) def get_language_model_collection(self) -> ProcessGroupCollection: @@ -638,7 +632,7 @@ def get_language_model_collection(self) -> ProcessGroupCollection: """ if self.language_model_module_name is None: raise ValueError("No language model specified for this collection") - return self.module_collections[self.language_model_module_name] + return self.module_pgs[self.language_model_module_name] def get_language_model_cp_size(self) -> int: """Get context parallel size for the language model. @@ -671,39 +665,39 @@ def get_module_collection(self, module_name: str) -> ProcessGroupCollection: Raises: ValueError: If module_name is not found in collections. """ - if module_name not in self.module_collections: + if module_name not in self.module_pgs: raise ValueError( f"Module '{module_name}' not found in collections. " - f"Available: {list(self.module_collections.keys())}" + f"Available: {list(self.module_pgs.keys())}" ) - return self.module_collections[module_name] + return self.module_pgs[module_name] def __len__(self): """Return the number of modules in this wrapper.""" - return len(self.module_collections) + return len(self.module_pgs) def __getitem__(self, module_name: str): """Get process group collection for a module using dict-like access.""" - return self.module_collections[module_name] + return self.module_pgs[module_name] def __iter__(self): """Iterate over all process group collections.""" - return iter(self.module_collections.values()) + return iter(self.module_pgs.values()) def keys(self): """Return module names.""" - return self.module_collections.keys() + return self.module_pgs.keys() def values(self): """Return process group collections.""" - return self.module_collections.values() + return self.module_pgs.values() def items(self): """Return (module_name, collection) pairs.""" - return self.module_collections.items() + return self.module_pgs.items() def __repr__(self): """Return a concise representation showing modules and their language model status.""" - modules_str = ', '.join(self.module_collections.keys()) + modules_str = ', '.join(self.module_pgs.keys()) lm_str = f", language_model_module_name='{self.language_model_module_name}'" if self.language_model_module_name else "" return f"MultiModuleProcessGroupCollection(modules=[{modules_str}]{lm_str})" From ebbb50980a80b0e36e471b667e1c9fda39cd3944 Mon Sep 17 00:00:00 2001 From: ykarnati Date: Wed, 28 Jan 2026 15:40:41 -0800 Subject: [PATCH 09/30] rename tensor conversion functions for clarity --- .../multimodule_communicator.py | 46 +++++++++---------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/megatron/core/pipeline_parallel/multimodule_communicator.py b/megatron/core/pipeline_parallel/multimodule_communicator.py index 7f79bc59de4..ce10e5888c0 100644 --- a/megatron/core/pipeline_parallel/multimodule_communicator.py +++ b/megatron/core/pipeline_parallel/multimodule_communicator.py @@ -45,10 +45,10 @@ class RankModuleInfo: is_terminal_stage: Optional[bool] = True -def _ensure_3d_tensor( +def _prepare_tensor_for_comm( tensor: Union[torch.Tensor, List[torch.Tensor], None] ) -> Union[torch.Tensor, List[torch.Tensor], None]: - """Ensure tensor is 3D for P2P/bridge communication. + """Prepare tensor for P2P/bridge communication by expanding to 3D if needed. P2P and bridge communicators expect 3D tensors. Handles both single tensors and lists of tensors (for VPP). @@ -62,18 +62,18 @@ def _ensure_3d_tensor( if tensor is None: return None if isinstance(tensor, list): - return [_ensure_3d_tensor(t) for t in tensor] + return [_prepare_tensor_for_comm(t) for t in tensor] if isinstance(tensor, torch.Tensor) and tensor.ndim == 2: return tensor.unsqueeze(-1) return tensor -def _restore_tensor_shape( +def _restore_tensor_from_comm( tensor: Union[torch.Tensor, List[torch.Tensor], None] ) -> Union[torch.Tensor, List[torch.Tensor], None]: - """Restore original tensor shape after P2P/bridge communication. + """Restore tensor shape after P2P/bridge communication by squeezing singleton dim. - Remove the extra dimension added by _ensure_3d_tensor if it was singleton. + Removes the extra dimension added by _prepare_tensor_for_comm if it was singleton. Handles both single tensors and lists of tensors (for VPP). Args: @@ -85,7 +85,7 @@ def _restore_tensor_shape( if tensor is None: return None if isinstance(tensor, list): - return [_restore_tensor_shape(t) for t in tensor] + return [_restore_tensor_from_comm(t) for t in tensor] if isinstance(tensor, torch.Tensor) and tensor.ndim == 3 and tensor.shape[-1] == 1: return tensor.squeeze(-1) return tensor @@ -313,13 +313,13 @@ def recv_forward( # from incoming modules. for bridge_comm in rank_module_info.bridge_comms_as_dest_module: received_tensor = bridge_comm.recv_forward() - input_dict[bridge_comm.src_module_name] = _restore_tensor_shape(received_tensor) + input_dict[bridge_comm.src_module_name] = _restore_tensor_from_comm(received_tensor) else: # If not first stage, receive forward activation tensor from P2P communicator. received_tensor = rank_module_info.p2p_communicator.recv_forward( tensor_shapes=tensor_shape, is_first_stage=False ) - input_dict[module_name] = _restore_tensor_shape(received_tensor) + input_dict[module_name] = _restore_tensor_from_comm(received_tensor) return input_dict def send_forward(self, output_dict: Dict[str, torch.Tensor], is_last_stage: bool = False): @@ -333,11 +333,11 @@ def send_forward(self, output_dict: Dict[str, torch.Tensor], is_last_stage: bool # If last stage, and has outgoing modules, send forward activation # by using bridge communicator. for bridge_comm in rank_module_info.bridge_comms_as_src_module: - tensor_to_send = _ensure_3d_tensor(output_dict[module_name]) + tensor_to_send = _prepare_tensor_for_comm(output_dict[module_name]) bridge_comm.send_forward(tensor_to_send) else: # If not last stage, send forward activation by using P2P communicator. - tensor_to_send = _ensure_3d_tensor(output_dict[module_name]) + tensor_to_send = _prepare_tensor_for_comm(output_dict[module_name]) rank_module_info.p2p_communicator.send_forward( tensor_to_send, is_last_stage=False ) @@ -363,17 +363,17 @@ def send_forward_recv_backward( # If last stage, and has outgoing modules, send forward activation and # receive backward gradient by using bridge communicator. for bridge_comm in rank_module_info.bridge_comms_as_src_module: - tensor_to_send = _ensure_3d_tensor(output_dict[module_name]) + tensor_to_send = _prepare_tensor_for_comm(output_dict[module_name]) grad = bridge_comm.send_forward_recv_backward(tensor_to_send) - grad_dict[bridge_comm.src_module_name] = _restore_tensor_shape(grad) + grad_dict[bridge_comm.src_module_name] = _restore_tensor_from_comm(grad) else: # If not last stage, send forward activation and receive backward gradient # by using P2P communicator. - tensor_to_send = _ensure_3d_tensor(output_dict[module_name]) + tensor_to_send = _prepare_tensor_for_comm(output_dict[module_name]) grad = rank_module_info.p2p_communicator.send_forward_recv_backward( tensor_to_send, tensor_shapes=tensor_shape, is_last_stage=False ) - grad_dict[module_name] = _restore_tensor_shape(grad) + grad_dict[module_name] = _restore_tensor_from_comm(grad) return grad_dict def send_backward_recv_forward( @@ -397,17 +397,17 @@ def send_backward_recv_forward( for bridge_comm in rank_module_info.bridge_comms_as_dest_module: # If first stage, and has incoming modules, send backward gradient and # receive forward activation by using bridge communicator. - grad_to_send = _ensure_3d_tensor(grad_dict[bridge_comm.src_module_name]) + grad_to_send = _prepare_tensor_for_comm(grad_dict[bridge_comm.src_module_name]) received_tensor = bridge_comm.send_backward_recv_forward(grad_to_send) - input_dict[bridge_comm.src_module_name] = _restore_tensor_shape(received_tensor) + input_dict[bridge_comm.src_module_name] = _restore_tensor_from_comm(received_tensor) else: # If not first stage, send backward gradient and receive forward activation # by using P2P communicator. - grad_to_send = _ensure_3d_tensor(grad_dict[module_name]) + grad_to_send = _prepare_tensor_for_comm(grad_dict[module_name]) received_tensor = rank_module_info.p2p_communicator.send_backward_recv_forward( grad_to_send, tensor_shapes=tensor_shape, is_first_stage=False ) - input_dict[module_name] = _restore_tensor_shape(received_tensor) + input_dict[module_name] = _restore_tensor_from_comm(received_tensor) return input_dict def recv_backward( @@ -432,13 +432,13 @@ def recv_backward( # by using bridge communicator. for bridge_comm in rank_module_info.bridge_comms_as_src_module: grad = bridge_comm.recv_backward() - grad_dict[bridge_comm.src_module_name] = _restore_tensor_shape(grad) + grad_dict[bridge_comm.src_module_name] = _restore_tensor_from_comm(grad) else: # If not last stage, receive backward gradient by using P2P communicator. grad = rank_module_info.p2p_communicator.recv_backward( tensor_shapes=tensor_shape, is_last_stage=False ) - grad_dict[module_name] = _restore_tensor_shape(grad) + grad_dict[module_name] = _restore_tensor_from_comm(grad) return grad_dict def send_backward(self, grad_dict: Dict[str, torch.Tensor], is_first_stage: bool = False): @@ -452,11 +452,11 @@ def send_backward(self, grad_dict: Dict[str, torch.Tensor], is_first_stage: bool # If first stage, and has incoming modules, send backward activation # by using bridge communicator. for bridge_comm in rank_module_info.bridge_comms_as_dest_module: - grad_to_send = _ensure_3d_tensor(grad_dict[bridge_comm.src_module_name]) + grad_to_send = _prepare_tensor_for_comm(grad_dict[bridge_comm.src_module_name]) bridge_comm.send_backward(grad_to_send) else: # If not first stage, send backward activation by using P2P communicator. - grad_to_send = _ensure_3d_tensor(grad_dict[module_name]) + grad_to_send = _prepare_tensor_for_comm(grad_dict[module_name]) rank_module_info.p2p_communicator.send_backward( grad_to_send, is_first_stage=False ) From 0b6cefda01f4dbef44058692554c65935dae6a15 Mon Sep 17 00:00:00 2001 From: ykarnati Date: Mon, 2 Feb 2026 20:27:02 -0800 Subject: [PATCH 10/30] Fix linting issues: format code and remove unused imports --- .../pipeline_parallel/bridge_communicator.py | 4 +- .../multimodule_communicator.py | 16 ++-- megatron/core/pipeline_parallel/schedules.py | 39 ++++----- megatron/core/process_groups_config.py | 9 ++- .../test_multimodule_schedules.py | 79 ++++++++++++------- 5 files changed, 84 insertions(+), 63 deletions(-) diff --git a/megatron/core/pipeline_parallel/bridge_communicator.py b/megatron/core/pipeline_parallel/bridge_communicator.py index df206f6f7e6..8ebccb2c1dc 100644 --- a/megatron/core/pipeline_parallel/bridge_communicator.py +++ b/megatron/core/pipeline_parallel/bridge_communicator.py @@ -737,7 +737,9 @@ def send_backward_recv_forward( req.wait() # Concatenate received activations - aggregated_activation = torch.cat(received_activations_list, dim=self.dim_mapping['b']) + aggregated_activation = torch.cat( + received_activations_list, dim=self.dim_mapping['b'] + ) logging.debug( f"[Bridge Communicator] [send_backward_recv_forward] Rank {self.current_rank} " f"agg act shape {aggregated_activation.shape} sum {aggregated_activation.sum()}" diff --git a/megatron/core/pipeline_parallel/multimodule_communicator.py b/megatron/core/pipeline_parallel/multimodule_communicator.py index ce10e5888c0..883ef5c0ddd 100644 --- a/megatron/core/pipeline_parallel/multimodule_communicator.py +++ b/megatron/core/pipeline_parallel/multimodule_communicator.py @@ -313,7 +313,9 @@ def recv_forward( # from incoming modules. for bridge_comm in rank_module_info.bridge_comms_as_dest_module: received_tensor = bridge_comm.recv_forward() - input_dict[bridge_comm.src_module_name] = _restore_tensor_from_comm(received_tensor) + input_dict[bridge_comm.src_module_name] = _restore_tensor_from_comm( + received_tensor + ) else: # If not first stage, receive forward activation tensor from P2P communicator. received_tensor = rank_module_info.p2p_communicator.recv_forward( @@ -338,9 +340,7 @@ def send_forward(self, output_dict: Dict[str, torch.Tensor], is_last_stage: bool else: # If not last stage, send forward activation by using P2P communicator. tensor_to_send = _prepare_tensor_for_comm(output_dict[module_name]) - rank_module_info.p2p_communicator.send_forward( - tensor_to_send, is_last_stage=False - ) + rank_module_info.p2p_communicator.send_forward(tensor_to_send, is_last_stage=False) def send_forward_recv_backward( self, @@ -399,7 +399,9 @@ def send_backward_recv_forward( # receive forward activation by using bridge communicator. grad_to_send = _prepare_tensor_for_comm(grad_dict[bridge_comm.src_module_name]) received_tensor = bridge_comm.send_backward_recv_forward(grad_to_send) - input_dict[bridge_comm.src_module_name] = _restore_tensor_from_comm(received_tensor) + input_dict[bridge_comm.src_module_name] = _restore_tensor_from_comm( + received_tensor + ) else: # If not first stage, send backward gradient and receive forward activation # by using P2P communicator. @@ -457,9 +459,7 @@ def send_backward(self, grad_dict: Dict[str, torch.Tensor], is_first_stage: bool else: # If not first stage, send backward activation by using P2P communicator. grad_to_send = _prepare_tensor_for_comm(grad_dict[module_name]) - rank_module_info.p2p_communicator.send_backward( - grad_to_send, is_first_stage=False - ) + rank_module_info.p2p_communicator.send_backward(grad_to_send, is_first_stage=False) @staticmethod def compute_total_pipeline_stages( diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index fc9ed2f9db4..08dac7ed58a 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -2,7 +2,7 @@ import contextlib from functools import partial -from typing import Callable, Dict, Iterator, List, Optional, Union +from typing import Callable, Iterator, List, Optional, Union import torch from torch.autograd.variable import Variable @@ -12,9 +12,7 @@ from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( FineGrainedActivationOffloadingInterface as off_interface, ) -from megatron.core.pipeline_parallel.multimodule_communicator import ( - MultiModulePipelineCommunicator, -) +from megatron.core.pipeline_parallel.multimodule_communicator import MultiModulePipelineCommunicator from megatron.core.pipeline_parallel.p2p_communication import P2PCommunicator from megatron.core.pipeline_parallel.utils import ( backward_step_multimodule, @@ -24,8 +22,8 @@ is_vp_last_stage, ) from megatron.core.process_groups_config import ( - ProcessGroupCollection, MultiModuleProcessGroupCollection, + ProcessGroupCollection, ) from megatron.core.transformer.cuda_graphs import create_cudagraphs from megatron.core.transformer.enums import CudaGraphScope @@ -673,9 +671,7 @@ def forward_backward_no_pipelining( ) total_num_tokens += num_tokens if not forward_only: - backward_step( - input_tensor, output_tensor, output_tensor_grad, config - ) + backward_step(input_tensor, output_tensor, output_tensor_grad, config) # Run computation for last microbatch out of context handler (want to # synchronize gradients). output_tensor, num_tokens = forward_step( @@ -1326,9 +1322,7 @@ def backward_step_helper(virtual_microbatch_id): virtual_microbatch_id, model_chunk_id ) - input_tensor_grad = backward_step( - input_tensor, output_tensor, output_tensor_grad, config - ) + input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad, config) backward_step_helper_postprocess(virtual_microbatch_id) @@ -2016,7 +2010,9 @@ def forward_backward_pipelining_without_interleaving( first_val_step: Optional[bool] = None, adjust_tensor_shapes_fn: Optional[Callable] = None, p2p_communicator: Optional[P2PCommunicator] = None, - pg_collection: Optional[Union[ProcessGroupCollection, MultiModuleProcessGroupCollection]] = None, + pg_collection: Optional[ + Union[ProcessGroupCollection, MultiModuleProcessGroupCollection] + ] = None, force_all_reduce: Optional[bool] = False, ): """Run non-interleaved 1F1B schedule, with communication between pipeline @@ -2041,7 +2037,8 @@ def forward_backward_pipelining_without_interleaving( tp_group, cp_group, cp_size = None, None, None - # Determine if this is a multi-module pipeline (used for validation and backward function selection) + # Determine if this is a multi-module pipeline + # (used for validation and backward function selection) is_multimodule = isinstance(pg_collection, MultiModuleProcessGroupCollection) or isinstance( p2p_communicator, MultiModulePipelineCommunicator ) @@ -2091,8 +2088,8 @@ def forward_backward_pipelining_without_interleaving( else: raise TypeError( - f"pg_collection must be ProcessGroupCollection or MultiModuleProcessGroupCollection, " - f"got {type(pg_collection)}" + f"pg_collection must be ProcessGroupCollection or " + f"MultiModuleProcessGroupCollection, got {type(pg_collection)}" ) else: @@ -2263,9 +2260,7 @@ def enable_grad_sync(): total_num_tokens += num_tokens if forward_only: - p2p_communicator.send_forward( - output_tensor, p2p_communicator.is_pp_last_stage - ) + p2p_communicator.send_forward(output_tensor, p2p_communicator.is_pp_last_stage) if not last_iteration: input_tensor = p2p_communicator.recv_forward( recv_tensor_shapes, p2p_communicator.is_pp_first_stage @@ -2302,9 +2297,7 @@ def enable_grad_sync(): ) else: input_tensor = p2p_communicator.send_backward_recv_forward( - input_tensor_grad, - recv_tensor_shapes, - p2p_communicator.is_pp_first_stage, + input_tensor_grad, recv_tensor_shapes, p2p_communicator.is_pp_first_stage ) # Run cooldown backward passes. @@ -2331,9 +2324,7 @@ def enable_grad_sync(): input_tensor, output_tensor, output_tensor_grad, config ) - p2p_communicator.send_backward( - input_tensor_grad, p2p_communicator.is_pp_first_stage - ) + p2p_communicator.send_backward(input_tensor_grad, p2p_communicator.is_pp_first_stage) # Launch any remaining grad reductions. if no_sync_context is not None: diff --git a/megatron/core/process_groups_config.py b/megatron/core/process_groups_config.py index 1513184891e..a1afaa96513 100644 --- a/megatron/core/process_groups_config.py +++ b/megatron/core/process_groups_config.py @@ -581,7 +581,8 @@ class MultiModuleProcessGroupCollection: Attributes: module_pgs: Dict mapping module names to ProcessGroupCollection objects - language_model_module_name: Key identifying the language model module (None if no LLM on this rank) + language_model_module_name: Key identifying the language model module + (None if no LLM on this rank) Example: # Colocated rank with encoder and LLM @@ -699,5 +700,9 @@ def items(self): def __repr__(self): """Return a concise representation showing modules and their language model status.""" modules_str = ', '.join(self.module_pgs.keys()) - lm_str = f", language_model_module_name='{self.language_model_module_name}'" if self.language_model_module_name else "" + lm_str = ( + f", language_model_module_name='{self.language_model_module_name}'" + if self.language_model_module_name + else "" + ) return f"MultiModuleProcessGroupCollection(modules=[{modules_str}]{lm_str})" diff --git a/tests/unit_tests/pipeline_parallel/test_multimodule_schedules.py b/tests/unit_tests/pipeline_parallel/test_multimodule_schedules.py index 9df3037da70..80e6bad9362 100644 --- a/tests/unit_tests/pipeline_parallel/test_multimodule_schedules.py +++ b/tests/unit_tests/pipeline_parallel/test_multimodule_schedules.py @@ -16,9 +16,7 @@ from megatron.core.distributed.finalize_model_grads import finalize_model_grads from megatron.core.hyper_comm_grid import HyperCommGrid from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec -from megatron.core.pipeline_parallel.multimodule_communicator import ( - MultiModulePipelineCommunicator, -) +from megatron.core.pipeline_parallel.multimodule_communicator import MultiModulePipelineCommunicator from megatron.core.pipeline_parallel.utils import is_pp_first_stage, is_pp_last_stage from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed @@ -106,11 +104,13 @@ def create_transformer_block(hidden_size, pg_collection, dtype=torch.bfloat16): bf16=(dtype == torch.bfloat16), ) - block = TransformerBlock( - config, - get_gpt_layer_with_transformer_engine_spec(), - pg_collection=pg_collection, - ).cuda().to(dtype) + block = ( + TransformerBlock( + config, get_gpt_layer_with_transformer_engine_spec(), pg_collection=pg_collection + ) + .cuda() + .to(dtype) + ) with torch.no_grad(): for mod in block.modules(): @@ -166,16 +166,18 @@ def __init__(self, encoder_configs, llm_config, hidden_size): for enc_cfg in encoder_configs: name = enc_cfg['name'] module, grid = create_module_with_grid( - enc_cfg['tp'], enc_cfg['pp'], enc_cfg['dp'], - enc_cfg['grid_offset'], hidden_size + enc_cfg['tp'], enc_cfg['pp'], enc_cfg['dp'], enc_cfg['grid_offset'], hidden_size ) self.encoders[name] = module self.encoder_grids[name] = grid # Create LLM self.llm, self.llm_grid = create_module_with_grid( - llm_config['tp'], llm_config['pp'], llm_config['dp'], - llm_config['grid_offset'], hidden_size + llm_config['tp'], + llm_config['pp'], + llm_config['dp'], + llm_config['grid_offset'], + hidden_size, ) # Track all modules for gradient sync @@ -274,8 +276,7 @@ def forward(self, hidden_states): if encoder is not None and self.is_rank_in_grid(self.encoder_grids[name]): pp_group = self.encoder_grids[name].get_pg("pp") input_tensor = ( - hidden_states if is_pp_first_stage(pp_group) - else self.input_tensors[name] + hidden_states if is_pp_first_stage(pp_group) else self.input_tensors[name] ) output_dict[name] = encoder(input_tensor, attention_mask=None) @@ -304,8 +305,11 @@ def __iter__(self): def __next__(self): return torch.randn( - self.seq_length, self.micro_batch_size, self.hidden_size, - device='cuda', dtype=torch.bfloat16 + self.seq_length, + self.micro_batch_size, + self.hidden_size, + device='cuda', + dtype=torch.bfloat16, ) @@ -315,8 +319,7 @@ def __next__(self): def run_multimodule_schedule_test( - encoder_configs, llm_config, hidden_size, seq_length, - micro_batch_size, num_microbatches + encoder_configs, llm_config, hidden_size, seq_length, micro_batch_size, num_microbatches ): """Run multimodule schedule test with given configuration. @@ -352,7 +355,8 @@ def run_multimodule_schedule_test( config.finalize_model_grads_func = model.finalize_model_grads config.grad_scale_func = lambda loss: ( torch.tensor(loss, dtype=torch.float32, device='cuda', requires_grad=True) - if isinstance(loss, (int, float)) else loss + if isinstance(loss, (int, float)) + else loss ) config.hidden_size = hidden_size model.config = config @@ -377,7 +381,10 @@ def run_multimodule_schedule_test( if grid.rank_offset <= rank < grid.rank_offset + grid.size: pg_collection = add_embedding_groups(get_pg_collection(grid)) break - if pg_collection is None and model.llm_grid.rank_offset <= rank < model.llm_grid.rank_offset + model.llm_grid.size: + if ( + pg_collection is None + and model.llm_grid.rank_offset <= rank < model.llm_grid.rank_offset + model.llm_grid.size + ): pg_collection = add_embedding_groups(get_pg_collection(model.llm_grid)) # Define step function @@ -442,8 +449,12 @@ def test_single_encoder_2gpu(self): llm_config = {'tp': 1, 'pp': 1, 'dp': 1, 'grid_offset': 1} run_multimodule_schedule_test( - encoder_configs, llm_config, - hidden_size=512, seq_length=64, micro_batch_size=2, num_microbatches=4 + encoder_configs, + llm_config, + hidden_size=512, + seq_length=64, + micro_batch_size=2, + num_microbatches=4, ) def test_dual_encoder_2gpu(self): @@ -458,8 +469,12 @@ def test_dual_encoder_2gpu(self): llm_config = {'tp': 1, 'pp': 1, 'dp': 1, 'grid_offset': 1} run_multimodule_schedule_test( - encoder_configs, llm_config, - hidden_size=512, seq_length=64, micro_batch_size=2, num_microbatches=4 + encoder_configs, + llm_config, + hidden_size=512, + seq_length=64, + micro_batch_size=2, + num_microbatches=4, ) def test_single_encoder_8gpu(self): @@ -471,8 +486,12 @@ def test_single_encoder_8gpu(self): llm_config = {'tp': 2, 'pp': 2, 'dp': 1, 'grid_offset': 4} run_multimodule_schedule_test( - encoder_configs, llm_config, - hidden_size=1024, seq_length=512, micro_batch_size=4, num_microbatches=16 + encoder_configs, + llm_config, + hidden_size=1024, + seq_length=512, + micro_batch_size=4, + num_microbatches=16, ) def test_dual_encoder_8gpu(self): @@ -487,6 +506,10 @@ def test_dual_encoder_8gpu(self): llm_config = {'tp': 2, 'pp': 2, 'dp': 1, 'grid_offset': 4} run_multimodule_schedule_test( - encoder_configs, llm_config, - hidden_size=1024, seq_length=512, micro_batch_size=4, num_microbatches=16 + encoder_configs, + llm_config, + hidden_size=1024, + seq_length=512, + micro_batch_size=4, + num_microbatches=16, ) From 7d566d99759cc8f99b444093878bf9c9b8782e92 Mon Sep 17 00:00:00 2001 From: ykarnati Date: Mon, 2 Feb 2026 14:23:21 -0800 Subject: [PATCH 11/30] Add RankRole and ModuleStageInfo for multi-module pipeline parallelism Introduce data classes to manage rank roles in multi-module PP setups: - ModuleStageInfo: tracks first/last stage position within a module - RankRole: tracks which modules a rank participates in and their stages These classes enable selective module initialization and stage-aware forward passes when different modules run on separate PP grids. Signed-off-by: ykarnati --- megatron/core/models/mimo/config/__init__.py | 3 +- megatron/core/models/mimo/config/role.py | 65 ++++++++++++++++++++ 2 files changed, 67 insertions(+), 1 deletion(-) create mode 100644 megatron/core/models/mimo/config/role.py diff --git a/megatron/core/models/mimo/config/__init__.py b/megatron/core/models/mimo/config/__init__.py index 8371675a22d..3da744a6fb2 100644 --- a/megatron/core/models/mimo/config/__init__.py +++ b/megatron/core/models/mimo/config/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. from megatron.core.models.mimo.config.base_configs import MimoModelConfig +from megatron.core.models.mimo.config.role import ModuleStageInfo, RankRole -__all__ = ['MimoModelConfig'] +__all__ = ['MimoModelConfig', 'ModuleStageInfo', 'RankRole'] diff --git a/megatron/core/models/mimo/config/role.py b/megatron/core/models/mimo/config/role.py new file mode 100644 index 00000000000..6a4650c45c4 --- /dev/null +++ b/megatron/core/models/mimo/config/role.py @@ -0,0 +1,65 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +"""Data classes for MIMO rank role management in multi-module pipeline parallelism.""" + +from dataclasses import dataclass, field +from typing import Dict, List, Optional + + +@dataclass +class ModuleStageInfo: + """Information about a rank's stage position within a module's pipeline. + + Args: + is_first_stage: True if this rank is the first PP stage for this module. + is_last_stage: True if this rank is the last PP stage for this module. + """ + + is_first_stage: bool + is_last_stage: bool + + +@dataclass +class RankRole: + """Describes what modules this rank participates in for multi-module PP. + + This class captures the role of a specific rank in a multi-module pipeline + parallel setup, tracking which modules the rank participates in and their + stage positions. + + Args: + modules: Dict mapping module names to their stage info for modules + this rank participates in. + language_module_name: Name of the language module, used to distinguish + encoders from the language model. + """ + + modules: Dict[str, ModuleStageInfo] = field(default_factory=dict) + language_module_name: Optional[str] = None + + @property + def has_modality_modules(self) -> bool: + """Return True if this rank participates in any modality (non-language) module.""" + return any(name != self.language_module_name for name in self.modules) + + @property + def has_language_module(self) -> bool: + """Return True if this rank participates in the language module.""" + return self.language_module_name is not None and self.language_module_name in self.modules + + @property + def modality_module_names(self) -> List[str]: + """Return names of modality modules (non-language) this rank participates in.""" + return [name for name in self.modules if name != self.language_module_name] + + def is_first_stage(self, module_name: str) -> bool: + """Check if this rank is the first stage for a given module.""" + if module_name not in self.modules: + return False + return self.modules[module_name].is_first_stage + + def is_last_stage(self, module_name: str) -> bool: + """Check if this rank is the last stage for a given module.""" + if module_name not in self.modules: + return False + return self.modules[module_name].is_last_stage From 997dfa5ee06f76451d720926b1889358b53a0307 Mon Sep 17 00:00:00 2001 From: ykarnati Date: Mon, 2 Feb 2026 14:24:57 -0800 Subject: [PATCH 12/30] Add stage-aware forward pass to modality submodules Enable modality submodules to operate in multi-stage PP configurations: - Add is_first_stage/is_last_stage as immutable properties - First stage: runs encoder on raw inputs - Intermediate stages: pass through hidden states - Last stage: applies input projection before language model Update from_spec() to pass stage info through constructor for proper initialization based on pipeline position. Signed-off-by: ykarnati --- megatron/core/models/mimo/submodules/audio.py | 52 +++++++++------ megatron/core/models/mimo/submodules/base.py | 65 +++++++++++++++++-- .../core/models/mimo/submodules/vision.py | 46 +++++++++---- 3 files changed, 123 insertions(+), 40 deletions(-) diff --git a/megatron/core/models/mimo/submodules/audio.py b/megatron/core/models/mimo/submodules/audio.py index ae907d7ac86..5b4910a2989 100644 --- a/megatron/core/models/mimo/submodules/audio.py +++ b/megatron/core/models/mimo/submodules/audio.py @@ -126,30 +126,42 @@ def project_embeddings( return embeddings - def forward(self, encoder_inputs: Dict[str, Any]) -> Optional[torch.Tensor]: - """Forward pass for audio modality submodules. + def forward( + self, + encoder_inputs: Optional[Dict[str, Any]] = None, + hidden_states: Optional[torch.Tensor] = None, + ) -> Optional[torch.Tensor]: + """Process audio data through encoding and projection. Args: encoder_inputs: Dictionary where keys match encoder names in self.encoders and values are dictionaries of encoder-specific parameters. - Example: { - "whisper": {"input_features": features}, - "wav2vec": {"input_values": waveform} - } + Used when is_first_stage=True. + hidden_states: Hidden states from previous pipeline stage. + Used when is_first_stage=False. Returns: - Flattened audio embeddings with shape [total_embeddings, hidden_dim], - or None if no valid inputs were provided. + - If is_last_stage: projected embeddings ready for language model + - If not is_last_stage: hidden states for next pipeline stage + - None if no valid input provided """ - - embeddings = self.encode(encoder_inputs) - # embeddings is a list of tensors, each tensor is a flattened audio embedding - - # If no embeddings were produced, return None - if not embeddings: - return None - - # Project embeddings - projected = self.project_embeddings(embeddings, is_input=True) - logger.debug(f"Projected audio embeddings shape: {projected.shape}") - return projected # [total_embeddings, hidden_dim] + # Determine input based on stage position + if self.is_first_stage: + if encoder_inputs is None: + return None + # Encode the audio + embeddings = self.encode(encoder_inputs) + if not embeddings: + return None + combined = self.combine_embeddings(embeddings) + else: + if hidden_states is None: + return None + # Use hidden states from previous stage + combined = hidden_states + + # Project only if last stage + if self.is_last_stage: + return self.project_embeddings([combined], is_input=True) + else: + return combined diff --git a/megatron/core/models/mimo/submodules/base.py b/megatron/core/models/mimo/submodules/base.py index 8b11ba7fcb9..a52c9cf1cc6 100644 --- a/megatron/core/models/mimo/submodules/base.py +++ b/megatron/core/models/mimo/submodules/base.py @@ -42,15 +42,30 @@ def __init__( decoders: Optional[Dict[str, nn.Module]] = None, input_projections: Optional[List[nn.Module]] = None, output_projections: Optional[List[nn.Module]] = None, + is_first_stage: bool = True, + is_last_stage: bool = True, **kwargs, ) -> None: - """Initialize the modality submodules.""" + """Initialize the modality submodules. + + Args: + encoders: Dict of encoder modules + decoders: Dict of decoder modules + input_projections: List of input projection modules + output_projections: List of output projection modules + is_first_stage: Whether this is the first PP stage for this module + is_last_stage: Whether this is the last PP stage for this module + """ super().__init__() self.encoders = nn.ModuleDict(encoders or {}) self.decoders = nn.ModuleDict(decoders or {}) self.input_projections = nn.ModuleList(input_projections or []) self.output_projections = nn.ModuleList(output_projections or []) + # Stage info for multi-module pipeline parallelism (immutable after init) + self._is_first_stage: bool = is_first_stage + self._is_last_stage: bool = is_last_stage + warnings.warn( "ModalitySubmodules is experimental and still under active development. " "The API may change without notice in future releases.", @@ -58,21 +73,44 @@ def __init__( stacklevel=2, ) + @property + def is_first_stage(self) -> bool: + """Whether this is the first pipeline stage for this module.""" + return self._is_first_stage + + @property + def is_last_stage(self) -> bool: + """Whether this is the last pipeline stage for this module.""" + return self._is_last_stage + @classmethod - def from_spec(cls, module_spec: ModuleSpec) -> 'ModalitySubmodules': + def from_spec( + cls, + module_spec: ModuleSpec, + is_first_stage: bool = True, + is_last_stage: bool = True, + ) -> 'ModalitySubmodules': """Create a modality submodule from ModuleSpec configuration. Args: module_spec (ModuleSpec): The module specification for this modality submodule + is_first_stage (bool): Whether this is the first pipeline stage for this module. + Controls encoder initialization. Defaults to True. + is_last_stage (bool): Whether this is the last pipeline stage for this module. + Controls input projection initialization (only needed on last stage). + Defaults to True. Returns: ModalitySubmodules: An instance of the modality submodule """ - logger.debug(f"Creating {cls.__name__} from spec") + logger.debug( + f"Creating {cls.__name__} from spec (is_first_stage={is_first_stage}, " + f"is_last_stage={is_last_stage})" + ) params = module_spec.params or {} submodules = module_spec.submodules or {} - # Build component lists from submodules dictionary + # Build encoders (needed on all stages for pipeline processing) encoders = {} if 'encoders' in submodules: for encoder_name, encoder_spec in submodules['encoders'].items(): @@ -80,6 +118,7 @@ def from_spec(cls, module_spec: ModuleSpec) -> 'ModalitySubmodules': encoder = build_module(encoder_spec) encoders[encoder_name] = encoder + # Build decoders (needed on all stages for pipeline processing) decoders = {} if 'decoders' in submodules: for decoder_name, decoder_spec in submodules['decoders'].items(): @@ -87,23 +126,35 @@ def from_spec(cls, module_spec: ModuleSpec) -> 'ModalitySubmodules': decoder = build_module(decoder_spec) decoders[decoder_name] = decoder + # Build input projections only on last stage + # (projection happens after encoding, before sending to language model) input_projections = [] - if 'input_projections' in submodules: + if is_last_stage and 'input_projections' in submodules: for proj_spec in submodules['input_projections']: logger.debug( f"Building {cls.__name__} input projection: {proj_spec.module.__name__}" ) projection = build_module(proj_spec) input_projections.append(projection) + elif 'input_projections' in submodules: + logger.debug( + f"Skipping {cls.__name__} input projections (not last stage)" + ) + # Build output projections only on first stage + # (projection happens before decoding, after receiving from language model) output_projections = [] - if 'output_projections' in submodules: + if is_first_stage and 'output_projections' in submodules: for proj_spec in submodules['output_projections']: logger.debug( f"Building {cls.__name__} output projection: {proj_spec.module.__name__}" ) projection = build_module(proj_spec) output_projections.append(projection) + elif 'output_projections' in submodules: + logger.debug( + f"Skipping {cls.__name__} output projections (not first stage)" + ) # Pass any additional parameters from the params dictionary additional_params = params.copy() @@ -117,6 +168,8 @@ def from_spec(cls, module_spec: ModuleSpec) -> 'ModalitySubmodules': decoders=decoders, input_projections=input_projections, output_projections=output_projections, + is_first_stage=is_first_stage, + is_last_stage=is_last_stage, **additional_params, ) diff --git a/megatron/core/models/mimo/submodules/vision.py b/megatron/core/models/mimo/submodules/vision.py index 795cb18a119..dfd22e2b392 100644 --- a/megatron/core/models/mimo/submodules/vision.py +++ b/megatron/core/models/mimo/submodules/vision.py @@ -40,6 +40,7 @@ def __init__( decoders=decoders, input_projections=input_projections, output_projections=output_projections, + **kwargs, ) if self.input_projections: @@ -160,25 +161,42 @@ def project_embeddings( return embeddings - def forward(self, encoder_inputs: Dict[str, Any]) -> Optional[torch.Tensor]: + def forward( + self, + encoder_inputs: Optional[Dict[str, Any]] = None, + hidden_states: Optional[torch.Tensor] = None, + ) -> Optional[torch.Tensor]: """Process image data through encoding and projection. Args: encoder_inputs: Dictionary where keys match encoder names in self.encoders and values are dictionaries of encoder-specific parameters. - Example: {"clip": {"pixel_values": images}, "vit": {"images": vit_images}} + Used when is_first_stage=True. + hidden_states: Hidden states from previous pipeline stage. + Used when is_first_stage=False. Returns: - Flattened image embeddings with shape [total_embeddings, hidden_dim], - or None if no valid inputs were provided. + - If is_last_stage: projected embeddings ready for language model + - If not is_last_stage: hidden states for next pipeline stage + - None if no valid input provided """ - # Encode the images - embeddings = self.encode(encoder_inputs) - - # If no embeddings were produced, return None - if not embeddings: - return None - - projected = self.project_embeddings(embeddings, is_input=True) - logging.debug(f"Projected audio embeddings shape: {projected.shape}") - return projected # [total_embeddings, hidden_dim] + # Determine input based on stage position + if self.is_first_stage: + if encoder_inputs is None: + return None + # Encode the images + embeddings = self.encode(encoder_inputs) + if not embeddings: + return None + combined = self.combine_embeddings(embeddings) + else: + if hidden_states is None: + return None + # Use hidden states from previous stage + combined = hidden_states + + # Project only if last stage + if self.is_last_stage: + return self.project_embeddings([combined], is_input=True) + else: + return combined From a1a8fdc5a086d52805d2dd0b575c6ca1b5f15c3d Mon Sep 17 00:00:00 2001 From: ykarnati Date: Mon, 2 Feb 2026 14:54:06 -0800 Subject: [PATCH 13/30] Update MimoModel for multi-module pipeline parallelism Add support for running encoder and language modules on separate PP grids: - Determine rank role based on module_to_grid_map configuration - Selective module initialization based on role (encoder-only or LM-only) - Stage-aware forward dispatching based on role - Validate grid map configuration requires language_module_key The forward pass now routes to _forward_encoders or _forward_language_module based on the rank's assigned role in the multi-module PP setup. Signed-off-by: ykarnati --- .../core/models/mimo/config/base_configs.py | 11 +- megatron/core/models/mimo/model/base.py | 337 ++++++++++++++++-- megatron/core/process_groups_config.py | 11 +- 3 files changed, 323 insertions(+), 36 deletions(-) diff --git a/megatron/core/models/mimo/config/base_configs.py b/megatron/core/models/mimo/config/base_configs.py index 8b170abe152..8dc2124ba08 100644 --- a/megatron/core/models/mimo/config/base_configs.py +++ b/megatron/core/models/mimo/config/base_configs.py @@ -2,7 +2,7 @@ import warnings from dataclasses import dataclass, field -from typing import Dict +from typing import Any, Dict, Optional from megatron.core.transformer.spec_utils import ModuleSpec @@ -20,6 +20,13 @@ class MimoModelConfig: Dictionary mapping modality names to their special token IDs. For example, {"vision": -200, "audio":32000}, these represent placeholders in the input_ids to insert the modality embeddings at the correct positions. + module_to_grid_map (Optional[Dict[str, Any]]): + Dictionary mapping module keys (e.g., "vision", "language") to their + corresponding grid configurations for non-colocated pipeline parallelism. + When None, all modules are assumed to be colocated on the same ranks. + language_module_key (Optional[str]): + The key used to identify the language module in the module_to_grid_map. + Required when module_to_grid_map is provided. """ warnings.warn( @@ -32,3 +39,5 @@ class MimoModelConfig: language_model_spec: ModuleSpec = field(default_factory=ModuleSpec) modality_submodules_spec: Dict[str, ModuleSpec] = field(default_factory=dict) special_token_ids: Dict[str, int] = field(default_factory=dict) + module_to_grid_map: Optional[Dict[str, Any]] = None + language_module_key: Optional[str] = None diff --git a/megatron/core/models/mimo/model/base.py b/megatron/core/models/mimo/model/base.py index 2f136a98466..dc51237084a 100644 --- a/megatron/core/models/mimo/model/base.py +++ b/megatron/core/models/mimo/model/base.py @@ -5,10 +5,14 @@ from typing import Any, Dict, Optional import torch +import torch.distributed as dist + +from megatron.core.models.mimo.config.role import ModuleStageInfo, RankRole from megatron.core.models.mimo.config import MimoModelConfig from megatron.core.transformer import MegatronModule from megatron.core.transformer.spec_utils import build_module +from megatron.core.utils import unwrap_model logger = logging.getLogger(__name__) @@ -52,6 +56,8 @@ def __init__(self, mimo_config: MimoModelConfig) -> None: ) self.mimo_config = mimo_config + self._validate_grid_map() + self.role = self._determine_role() # Use special token IDs from the config self.special_token_ids = ( @@ -138,25 +144,135 @@ def _initialize_submodules(self) -> None: """Initialize modality submodules from the ModuleSpec configurations. Only modalities present in the config will be instantiated. - For each modality in the config, builds the corresponding submodule using from_spec. + When role is set, only initializes submodules this rank participates in. + Stage info is passed to from_spec() to conditionally skip projection + initialization on non-last stages (saves memory in pipeline parallelism). """ - for modality_name, submodule_spec in self.mimo_config.modality_submodules_spec.items(): - # Get the submodule class + # Skip if we have a role and this module isn't in it + if self.role is not None and modality_name not in self.role.modules: + logger.debug(f"Skipping {modality_name} submodule (not in role)") + continue + + # Determine stage info for this module + is_first_stage = True + is_last_stage = True + if self.role is not None and modality_name in self.role.modules: + stage_info = self.role.modules[modality_name] + is_first_stage = stage_info.is_first_stage + is_last_stage = stage_info.is_last_stage + submodule_class = submodule_spec.module - logger.debug(f"Building {modality_name} submodule using {submodule_class.__name__}") + logger.debug( + f"Building {modality_name} submodule using {submodule_class.__name__} " + f"(is_first_stage={is_first_stage}, is_last_stage={is_last_stage})" + ) + + # Pass stage info to from_spec so projections are only built when needed + submodule = submodule_class.from_spec( + submodule_spec, + is_first_stage=is_first_stage, + is_last_stage=is_last_stage, + ) - # Use from_spec to instantiate the submodule - submodule = submodule_class.from_spec(submodule_spec) self.modality_submodules[modality_name] = submodule def _initialize_language_model(self) -> None: - """Initialize the language model.""" + """Initialize the language model. + + When role is set, only initializes if this rank participates in language module. + """ + # Skip if we have a role and don't participate in language module + if self.role is not None and not self.role.has_language_module: + logger.debug("Skipping language model initialization (not in role)") + self.language_model = None + return + logger.debug( f"Building language model using {self.mimo_config.language_model_spec.module.__name__}" ) self.language_model = build_module(self.mimo_config.language_model_spec) + def _validate_grid_map(self) -> None: + """Validate module_to_grid_map consistency with submodule config. + + Validates that: + - language_module_key is set when module_to_grid_map is provided + - module_to_grid_map keys exactly match modality_submodules_spec keys + language_module_key + + Raises: + ValueError: If validation fails. + """ + if not self.mimo_config.module_to_grid_map: + return + + # Require language_module_key when using multi-module PP + if self.mimo_config.language_module_key is None: + raise ValueError( + "language_module_key must be set when module_to_grid_map is provided. " + "Specify which module key identifies the language model." + ) + + grid_map_keys = set(self.mimo_config.module_to_grid_map.keys()) + submodule_keys = set(self.mimo_config.modality_submodules_spec.keys()) + submodule_keys.add(self.mimo_config.language_module_key) + + if grid_map_keys != submodule_keys: + missing_in_grid = submodule_keys - grid_map_keys + extra_in_grid = grid_map_keys - submodule_keys + raise ValueError( + f"module_to_grid_map keys must match modality_submodules_spec keys + " + f"language_module_key. Missing in grid_map: {missing_in_grid}, " + f"Extra in grid_map: {extra_in_grid}" + ) + + def _determine_role(self) -> Optional[RankRole]: + """Determine this rank's role based on grid map. + + Returns: + RankRole describing which modules this rank participates in, + or None if module_to_grid_map is not set (all modules on all ranks). + """ + if not self.mimo_config.module_to_grid_map: + return None + + current_rank = dist.get_rank() + modules = {} + + for module_name, grid in self.mimo_config.module_to_grid_map.items(): + # Check if current rank is in this grid + if not (grid.rank_offset <= current_rank < grid.rank_offset + grid.size): + continue + + # Check if PP dimension exists + if "pp" not in grid.dim_names: + # No PP dimension means single stage (both first and last) + modules[module_name] = ModuleStageInfo( + is_first_stage=True, + is_last_stage=True, + ) + continue + + # Get PP process group and determine stage + pp_group = grid.get_pg("pp") + pp_rank = pp_group.rank() + pp_size = pp_group.size() + is_first = (pp_rank == 0) + is_last = (pp_rank == pp_size - 1) + logger.info( + f"[_determine_role] Rank {current_rank}: module={module_name}, " + f"pp_rank={pp_rank}/{pp_size}, is_first_stage={is_first}, is_last_stage={is_last}" + ) + modules[module_name] = ModuleStageInfo( + is_first_stage=is_first, + is_last_stage=is_last, + ) + + return RankRole( + modules=modules, + language_module_name=self.mimo_config.language_module_key, + ) + def set_input_tensor(self, input_tensor): """Set input tensor for pipeline parallelism. @@ -164,18 +280,27 @@ def set_input_tensor(self, input_tensor): It passes the output tensor from the previous stage as input to this stage. Args: - input_tensor: Tensor or list of tensors passed between pipeline stages + input_tensor: Either: + - Dict[str, Tensor]: Maps module names to their input tensors (for multi-module PP) + - Tensor or List[Tensor]: Single tensor for language model (backward compat) Returns: None """ - # Handle case where input_tensor might be a list or a single tensor + # Store dict input for multi-module PP + if isinstance(input_tensor, dict): + self.input_tensors = input_tensor + return + + # Backward compatibility: single tensor or list if isinstance(input_tensor, list): - # For simplicity, just use the first tensor input_tensor = input_tensor[0] - # Pass the input tensor to the language model if it has a set_input_tensor method - if hasattr(self.language_model, 'set_input_tensor'): + # Store as input_tensors for consistency + self.input_tensors = input_tensor + + # Also delegate to language model for backward compatibility + if self.language_model is not None and hasattr(self.language_model, 'set_input_tensor'): self.language_model.set_input_tensor(input_tensor) def get_text_embeddings( @@ -204,7 +329,7 @@ def get_text_embeddings( position_ids[batch_idx, seq_idx].unsqueeze(0) if position_ids is not None else None ) - text_embeddings = self.language_model.embedding( + text_embeddings = unwrap_model(self.language_model).embedding( input_ids=input_ids_text, position_ids=position_ids_text ).squeeze( 1 @@ -228,35 +353,187 @@ def forward( attention_mask: Attention mask [batch_size, seq_length] loss_mask: Loss mask [batch_size, seq_length] labels: Labels for training - modality_inputs: Dictionary mapping modality names to encoder inputs. For example: - { - "images": { - "clip_encoder": {"pixel_values": clip_images}, - "vit_encoder": {"images": vit_images} - }, - "audio": { - "whisper_encoder": {"input_features": whisper_features} - } - } + modality_inputs: Dictionary mapping modality names to encoder inputs. Returns: - tuple: Tuple containing model outputs and loss mask + tuple: (output, loss_mask) where output semantics depend on role: + - Encoder-only ranks: Dict[str, Tensor] of encoder outputs + - Language module ranks: language model output (logits or loss) + - No role (all modules colocated): language model output + """ + # Get any tensors passed via set_input_tensor + input_tensors = getattr(self, 'input_tensors', None) + + if self.role is None: + # Original behavior: all modules on all ranks + return self._forward_all_modules( + input_ids, position_ids, attention_mask, + loss_mask, labels, modality_inputs + ) + + if self.role.has_modality_modules and not self.role.has_language_module: + # Encoder-only rank + return self._forward_encoders(modality_inputs, input_tensors), loss_mask + + if self.role.has_language_module and not self.role.has_modality_modules: + # Language-module-only rank + return self._forward_language_module( + input_ids, position_ids, attention_mask, + labels, input_tensors + ), loss_mask + + if self.role.has_modality_modules and self.role.has_language_module: + # Colocated encoders and language module is a configuration error + raise ValueError( + "Invalid configuration: Colocated encoders and language module on the same " + "rank is not supported in multi-module pipeline parallelism. Use separate " + "grids for encoders and language module, or disable multi-module PP by not " + "setting module_to_grid_map." + ) + + raise RuntimeError(f"Rank has no modules assigned in role: {self.role}") + + def _forward_encoders( + self, + modality_inputs: Optional[Dict[str, Dict[str, Any]]], + input_tensors: Optional[Dict[str, torch.Tensor]], + ) -> Dict[str, torch.Tensor]: + """Forward pass for encoder modules on this rank. + + Args: + modality_inputs: Raw inputs for each modality (images, audio, etc.) + input_tensors: Hidden states from previous pipeline stages + + Returns: + Dict mapping encoder names to their output tensors + """ + outputs = {} + + for encoder_name in self.role.modality_module_names: + if encoder_name not in self.modality_submodules: + continue + + submodule = self.modality_submodules[encoder_name] + + # Determine input based on stage position + if self.role.is_first_stage(encoder_name): + # First stage: use raw modality inputs + encoder_input = modality_inputs.get(encoder_name) if modality_inputs else None + if encoder_input is not None: + output = submodule.forward(encoder_inputs=encoder_input) + else: + output = None + else: + # Non-first stage: use hidden states from previous stage + hidden_states = input_tensors.get(encoder_name) if input_tensors else None + if hidden_states is not None: + output = submodule.forward(hidden_states=hidden_states) + else: + output = None + + if output is not None: + outputs[encoder_name] = output + + return outputs + + def _forward_language_module( + self, + input_ids: torch.Tensor, + position_ids: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor], + labels: Optional[torch.Tensor], + input_tensors: Optional[Dict[str, torch.Tensor]], + ) -> torch.Tensor: + """Forward pass for language module on this rank. + + Args: + input_ids: Token IDs + position_ids: Position IDs + attention_mask: Attention mask + labels: Labels for loss computation + input_tensors: Hidden states or embeddings from previous stage + + Returns: + Language model output (hidden states, logits, or loss depending on stage) + """ + lang_name = self.role.language_module_name + + if self.role.is_first_stage(lang_name): + # First stage: receive encoder embeddings, combine with text, pass to LM + # Build modality embeddings dict from encoder outputs + modality_embeddings = {} + if input_tensors: + for name, tensor in input_tensors.items(): + if name != lang_name: + modality_embeddings[name] = tensor + + # Get text embeddings + text_embeddings = self.get_text_embeddings( + input_ids, position_ids, self.special_token_ids + ) + modality_embeddings["text"] = text_embeddings + + # Combine all embeddings + combined_embeddings = self.align_embeddings_by_token_positions( + modality_embeddings=modality_embeddings, + input_ids=input_ids, + special_token_ids=self.special_token_ids, + ) + + lm_output = self.language_model( + input_ids=None, + position_ids=None, + decoder_input=combined_embeddings, + labels=labels, + attention_mask=attention_mask, + ) + else: + # Non-first stage: receive hidden states from previous LM stage + hidden_states = input_tensors.get(lang_name) if input_tensors else None + + # Set input tensor on language model for PP + if hidden_states is not None and hasattr(self.language_model, 'set_input_tensor'): + self.language_model.set_input_tensor(hidden_states) + + lm_output = self.language_model( + input_ids=None, + position_ids=None, + decoder_input=None, + labels=labels, + attention_mask=attention_mask, + ) + + # Key output for non-last stages so schedule can route to next LM stage + if not self.role.is_last_stage(lang_name): + return {lang_name: lm_output} + + return lm_output + + def _forward_all_modules( + self, + input_ids: torch.Tensor, + position_ids: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor], + loss_mask: Optional[torch.Tensor], + labels: Optional[torch.Tensor], + modality_inputs: Optional[Dict[str, Dict[str, Any]]], + ): + """Forward pass when all modules are on all ranks (no multi-module PP). + + This is the original behavior, preserved for backward compatibility. """ # 1. Process each modality to get embeddings modality_embeddings = {} for modality_name, submodule in self.modality_submodules.items(): - # Process the modality through its submodule if ( modality_inputs and modality_name in modality_inputs and modality_inputs[modality_name] is not None ): logger.debug(f"Processing {modality_name} modality") - # Get embeddings for this modality embeddings = submodule.forward(encoder_inputs=modality_inputs[modality_name]) if embeddings is not None: - # All embeddings are now in the format [num_tokens, hidden_dim] modality_embeddings[modality_name] = embeddings logger.debug( f"Generated embeddings for {modality_name} with shape {embeddings.shape}" @@ -271,10 +548,10 @@ def forward( # 2. Merge embeddings from different modalities logger.debug(f"Merging embeddings from {len(modality_embeddings)} modalities") combined_embeddings = self.align_embeddings_by_token_positions( - modality_embeddings=modality_embeddings, # [num_tokens, hidden_dim] for each modality - input_ids=input_ids, # Pass in batch-first format [b, s] + modality_embeddings=modality_embeddings, + input_ids=input_ids, special_token_ids=self.special_token_ids, - ) # [s, b, h] + ) logger.debug(f"Combined embeddings shape: {combined_embeddings.shape}") # 3. Forward pass through language model diff --git a/megatron/core/process_groups_config.py b/megatron/core/process_groups_config.py index a1afaa96513..c49a47d2f7c 100644 --- a/megatron/core/process_groups_config.py +++ b/megatron/core/process_groups_config.py @@ -639,12 +639,13 @@ def get_language_model_cp_size(self) -> int: """Get context parallel size for the language model. Returns: - Context parallel size for the language model. - - Raises: - ValueError: If no language model is specified for this collection. + Context parallel size for the language model, or 1 if no language + model is on this rank (e.g., encoder-only ranks in multi-module PP). """ - return self.get_language_model_collection().cp.size() + if not self.has_language_model(): + return 1 # Default CP size for non-LM ranks + cp = self.get_language_model_collection().cp + return cp.size() if cp is not None else 1 def has_language_model(self) -> bool: """Check if this rank has a language model. From b46a157d589e66d49ed1345ee15e5892092c2517 Mon Sep 17 00:00:00 2001 From: ykarnati Date: Mon, 2 Feb 2026 14:56:46 -0800 Subject: [PATCH 14/30] Add unit tests for multi-module pipeline parallelism Add comprehensive tests for multi-module PP functionality: - test_mimo_role.py: RankRole and ModuleStageInfo data classes - test_mimo_1f1b_schedule.py: 1F1B schedule with multi-module PP - Update existing tests for stage-aware submodule behavior Tests validate role determination, selective initialization, and stage-aware forward passes for both encoder-only and language-only ranks. Signed-off-by: ykarnati --- .../models/test_mimo_1f1b_schedule.py | 727 ++++++++++++++++++ .../models/test_mimo_audio_submodules.py | 30 + tests/unit_tests/models/test_mimo_model.py | 149 ++++ tests/unit_tests/models/test_mimo_role.py | 54 ++ .../unit_tests/models/test_mimo_submodules.py | 39 + 5 files changed, 999 insertions(+) create mode 100644 tests/unit_tests/models/test_mimo_1f1b_schedule.py create mode 100644 tests/unit_tests/models/test_mimo_role.py diff --git a/tests/unit_tests/models/test_mimo_1f1b_schedule.py b/tests/unit_tests/models/test_mimo_1f1b_schedule.py new file mode 100644 index 00000000000..b6cd67b6f54 --- /dev/null +++ b/tests/unit_tests/models/test_mimo_1f1b_schedule.py @@ -0,0 +1,727 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +"""Integration tests for MIMO model with 1F1B pipeline schedule. + +Run with: + torchrun --nproc_per_node=2 tests/unit_tests/models/test_mimo_1f1b_schedule.py +""" + +import logging +from typing import Dict + +import torch +import torch.distributed as dist + +import megatron.core.pipeline_parallel.schedules as schedule +from megatron.core.distributed import DistributedDataParallel, DistributedDataParallelConfig +from megatron.core.distributed.finalize_model_grads import finalize_model_grads +from megatron.core.hyper_comm_grid import HyperCommGrid +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.models.mimo.config.base_configs import MimoModelConfig +from megatron.core.models.mimo.model.base import MimoModel +from megatron.core.models.mimo.submodules.vision import VisionModalitySubmodules +from megatron.core.models.vision.multimodal_projector import MultimodalProjector +from megatron.core.pipeline_parallel.multimodule_communicator import ( + MultiModulePipelineCommunicator, +) +from megatron.core.pipeline_parallel.utils import is_pp_first_stage, is_pp_last_stage +from megatron.core.process_groups_config import MultiModuleProcessGroupCollection, ProcessGroupCollection +from megatron.core.transformer.mlp import MLP, MLPSubmodules +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_block import TransformerBlock +from megatron.core.transformer.transformer_config import TransformerConfig + +try: + from megatron.core.extensions.transformer_engine import ( + TEColumnParallelLinear, + TERowParallelLinear, + ) +except ImportError: + TEColumnParallelLinear = None + TERowParallelLinear = None + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +# ============================================================================ +# Helper Functions +# ============================================================================ + + +def create_hypercomm_grid(offset=0, tp=1, cp=1, pp=1, dp=1): + """Create a HyperCommGrid with specified parallelism.""" + grid = HyperCommGrid( + shape=[tp, cp, pp, dp, 1], + dim_names=["tp", "cp", "pp", "dp", "ep"], + rank_offset=offset, + backend="nccl", + ) + grid.create_pg(["tp"]) + grid.create_pg(["cp"]) + grid.create_pg(["pp"]) + grid.create_pg(["dp"]) + grid.create_pg(["dp", "cp"]) + grid.create_pg(["ep"]) + return grid + + +def get_pg_collection(grid): + """Get ProcessGroupCollection from grid.""" + pg_collection = ProcessGroupCollection() + pg_collection.tp = grid.get_pg("tp") + pg_collection.cp = grid.get_pg("cp") + pg_collection.pp = grid.get_pg("pp") + pg_collection.ep = grid.get_pg("ep") + pg_collection.dp = grid.get_pg("dp") + pg_collection.dp_cp = grid.get_pg(["dp", "cp"]) + return pg_collection + + +def add_embedding_groups(pg_collection): + """Add embedding groups to process group collection.""" + if not pg_collection.pp: + return pg_collection + + pp_ranks = sorted(dist.get_process_group_ranks(pg_collection.pp)) + pos_embd_ranks = [pp_ranks[0]] + embd_ranks = [pp_ranks[0]] + if pp_ranks[-1] != pp_ranks[0]: + embd_ranks.append(pp_ranks[-1]) + + pos_embd_pg = dist.new_group(ranks=pos_embd_ranks) + embd_pg = dist.new_group(ranks=embd_ranks) + + pg_collection.pos_embd = pos_embd_pg if is_pp_first_stage(pg_collection.pp) else None + pg_collection.embd = ( + embd_pg + if (is_pp_last_stage(pg_collection.pp) or is_pp_first_stage(pg_collection.pp)) + else None + ) + + return pg_collection + + +def get_pg_collection_with_embedding_groups(grid): + """Get ProcessGroupCollection with embedding groups.""" + return add_embedding_groups(get_pg_collection(grid)) + + +def is_rank_in_grid(grid): + """Check if current rank is in grid.""" + rank = dist.get_rank() + return grid.rank_offset <= rank < grid.rank_offset + grid.size + + +# ============================================================================ +# Model Spec Helpers +# ============================================================================ + + +def get_language_model_spec( + num_layers: int, + hidden_size: int, + num_attention_heads: int, + vocab_size: int, + seq_len: int, + pg_collection: ProcessGroupCollection, +): + """Get the language model spec.""" + pp_rank = dist.get_rank(pg_collection.pp) + pp_size = dist.get_world_size(pg_collection.pp) + pre_process = (pp_rank == 0) + post_process = (pp_rank == pp_size - 1) + + logger.info( + f"[get_language_model_spec] Rank {dist.get_rank()}: PP rank={pp_rank}/{pp_size}, " + f"pre_process={pre_process}, post_process={post_process}" + ) + + tp_size = pg_collection.tp.size() if pg_collection.tp is not None else 1 + pp_size = pg_collection.pp.size() if pg_collection.pp is not None else 1 + + lm_config = TransformerConfig( + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + use_cpu_initialization=True, + variable_seq_lengths=True, + moe_token_dispatcher_type='alltoall', + tensor_model_parallel_size=tp_size, + pipeline_model_parallel_size=pp_size, + pipeline_dtype=torch.bfloat16, + bf16=True, + cross_entropy_loss_fusion=True, + cross_entropy_fusion_impl='te', + ) + language_layer_spec = get_gpt_layer_with_transformer_engine_spec() + language_model_spec = ModuleSpec( + module=GPTModel, + params={ + "config": lm_config, + "transformer_layer_spec": language_layer_spec, + "vocab_size": vocab_size, + "max_sequence_length": seq_len, + "pre_process": pre_process, + "post_process": post_process, + "pg_collection": pg_collection, + }, + ) + return language_model_spec + + +def get_projection_config(hidden_size: int) -> TransformerConfig: + """Return a TransformerConfig for the vision projection MLP.""" + cfg = TransformerConfig( + num_layers=1, + hidden_size=hidden_size, + num_attention_heads=1, + ) + cfg.ffn_hidden_size = hidden_size + cfg.bias_activation_fusion = True + cfg.add_bias_linear = True + cfg.activation_func = torch.nn.functional.gelu + return cfg + + +def get_projection_layer_spec() -> ModuleSpec: + """Layer spec for the vision-projection MLP.""" + if TEColumnParallelLinear is None or TERowParallelLinear is None: + raise RuntimeError("TEColumnParallelLinear and TERowParallelLinear are required") + return ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TEColumnParallelLinear, + linear_fc2=TERowParallelLinear, + ), + ) + + +def get_vision_submodules_spec( + num_layers: int, + hidden_size: int, + num_attention_heads: int, + language_hidden_size: int, + pg_collection: ProcessGroupCollection, +): + """Get the submodule spec for the vision modality.""" + vision_layer_spec = get_gpt_layer_with_transformer_engine_spec() + + tp_size = pg_collection.tp.size() if pg_collection.tp is not None else 1 + pp_size = pg_collection.pp.size() if pg_collection.pp is not None else 1 + + # Calculate pre/post process based on PP rank (same as language model spec) + pp_rank = dist.get_rank(pg_collection.pp) + pre_process = (pp_rank == 0) + post_process = (pp_rank == pp_size - 1) + + logger.info( + f"[get_vision_submodules_spec] Rank {dist.get_rank()}: PP rank={pp_rank}/{pp_size}, " + f"pre_process={pre_process}, post_process={post_process}" + ) + + vision_config = TransformerConfig( + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + use_cpu_initialization=True, + variable_seq_lengths=True, + moe_token_dispatcher_type='alltoall', + tensor_model_parallel_size=tp_size, + pipeline_model_parallel_size=pp_size, + pipeline_dtype=torch.bfloat16, + bf16=True, + ) + vision_encoder_spec = ModuleSpec( + module=TransformerBlock, + params={ + "config": vision_config, + "spec": vision_layer_spec, + "pg_collection": pg_collection, + "pre_process": pre_process, + "post_process": post_process, + }, + ) + + vision_projection_spec = ModuleSpec( + module=MultimodalProjector, + params={ + "config": get_projection_config(hidden_size=language_hidden_size), + "submodules": get_projection_layer_spec().submodules, + "projector_type": "mlp", + "input_size": vision_config.hidden_size, + "tp_group": pg_collection.tp, + }, + ) + + vision_submodule_spec = ModuleSpec( + module=VisionModalitySubmodules, + submodules={ + "encoders": {"clip_encoder": vision_encoder_spec}, + "input_projections": [vision_projection_spec], + }, + ) + + return vision_submodule_spec + + +def get_mimo_model( + encoder_name: str, + language_module_name: str, + encoder_grid: HyperCommGrid, + llm_grid: HyperCommGrid, + hidden_size: int, + num_layers: int, + vocab_size: int, + seq_len: int, +): + """Create MIMO model with TransformerBlock encoder and GPTModel LLM.""" + language_pg_collection = get_pg_collection_with_embedding_groups(llm_grid) + vision_pg_collection = get_pg_collection_with_embedding_groups(encoder_grid) + + # Always create full specs on all ranks (POC pattern) + language_model_spec = get_language_model_spec( + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=8, + vocab_size=vocab_size, + seq_len=seq_len, + pg_collection=language_pg_collection, + ) + + vision_submodule_spec = get_vision_submodules_spec( + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=8, + language_hidden_size=hidden_size, + pg_collection=vision_pg_collection, + ) + + module_to_grid_map = { + encoder_name: encoder_grid, + language_module_name: llm_grid, + } + topology = { + encoder_name: [language_module_name], + language_module_name: [], + } + + mimo_config = MimoModelConfig( + language_model_spec=language_model_spec, + modality_submodules_spec={encoder_name: vision_submodule_spec}, + special_token_ids={encoder_name: 50257}, + module_to_grid_map=module_to_grid_map, + language_module_key=language_module_name, + ) + + logger.info(f"[Rank {dist.get_rank()}] Creating MimoModel...") + mimo_model = MimoModel(mimo_config) + logger.info(f"[Rank {dist.get_rank()}] MimoModel created successfully") + + mimo_model.to(torch.device("cuda")).to(torch.bfloat16) + + # Wrap with DDP + ddp_config = DistributedDataParallelConfig( + overlap_grad_reduce=True, + bucket_size=10000, + use_distributed_optimizer=True, + ) + + if mimo_model.language_model is not None: + logger.info(f"[Rank {dist.get_rank()}] Wrapping language_model with DDP") + mimo_model.language_model = DistributedDataParallel( + config=mimo_model.language_model.config, + ddp_config=ddp_config, + module=mimo_model.language_model, + pg_collection=language_pg_collection, + ) + + if encoder_name in mimo_model.modality_submodules: + submodule = mimo_model.modality_submodules[encoder_name] + if submodule is not None: + logger.info(f"[Rank {dist.get_rank()}] Wrapping {encoder_name} submodule with DDP") + submodule = DistributedDataParallel( + config=submodule.encoders['clip_encoder'].config, + ddp_config=ddp_config, + module=submodule, + pg_collection=vision_pg_collection, + ) + mimo_model.modality_submodules[encoder_name] = submodule + + return mimo_model, module_to_grid_map, topology + + +# ============================================================================ +# Data Iterator +# ============================================================================ + + +class DataIterator: + """Simple data iterator for testing. + + Returns batches matching the POC's MockVLMDataset structure: + - input_ids: [batch_size, seq_length] with image_seq_length image tokens at start + - labels: [batch_size, seq_length] + - loss_mask: [batch_size, seq_length] + - position_ids: [batch_size, seq_length] + - modality_inputs: {modality_name: {encoder_name: {'hidden_states': tensor, 'attention_mask': None}}} + """ + + def __init__(self, hidden_size, seq_length, micro_batch_size, vocab_size, encoder_name, + image_token_id=50257, image_seq_length=None): + self.hidden_size = hidden_size + self.seq_length = seq_length + self.micro_batch_size = micro_batch_size + self.vocab_size = vocab_size + self.encoder_name = encoder_name + self.image_token_id = image_token_id + # Use half the sequence for image tokens by default + self.image_seq_length = image_seq_length or (seq_length // 2) + + def __iter__(self): + return self + + def __next__(self): + # Create encoder input: [image_seq_length, batch_size, hidden_size] + # This matches the number of image tokens in input_ids + encoder_hidden_states = torch.randn( + self.image_seq_length, self.micro_batch_size, self.hidden_size, + device='cuda', dtype=torch.bfloat16 + ) + + # Create input_ids with image tokens at the beginning (like MockVLMDataset) + # Shape: [batch_size, seq_length] + image_tokens = torch.full( + (self.micro_batch_size, self.image_seq_length), + self.image_token_id, + dtype=torch.long, device='cuda' + ) + text_tokens = torch.randint( + 1, self.vocab_size, # Avoid 0 (pad token) + (self.micro_batch_size, self.seq_length - self.image_seq_length), + device='cuda' + ) + input_ids = torch.cat([image_tokens, text_tokens], dim=1) + + # Create labels (copy of input_ids, with image tokens set to -100) + labels = input_ids.clone() + labels[input_ids == self.image_token_id] = -100 + + # Create loss_mask (0 for image tokens, 1 for text tokens) + loss_mask = torch.ones( + self.micro_batch_size, self.seq_length, + device='cuda', dtype=torch.float32 + ) + loss_mask[input_ids == self.image_token_id] = 0.0 + + return { + "input_ids": input_ids, + "labels": labels, + "loss_mask": loss_mask, + "position_ids": torch.arange( + self.seq_length, device='cuda' + ).unsqueeze(0).expand(self.micro_batch_size, -1).clone(), + # modality_inputs structure from POC + "modality_inputs": { + self.encoder_name: { + "clip_encoder": { + 'hidden_states': encoder_hidden_states, + 'attention_mask': None, + } + } + }, + } + + +# ============================================================================ +# Test Runner +# ============================================================================ + + +def run_mimo_1f1b_test( + encoder_tp: int, + encoder_pp: int, + encoder_dp: int, + encoder_offset: int, + llm_tp: int, + llm_pp: int, + llm_dp: int, + llm_offset: int, + hidden_size: int = 256, + num_layers: int = 2, + vocab_size: int = 1000, + seq_length: int = 64, + micro_batch_size: int = 2, + num_microbatches: int = 4, +): + """Run MIMO model through 1F1B schedule and verify.""" + encoder_name = "images" + language_module_name = "language_module" + + logger.info(f"[Rank {dist.get_rank()}] Creating grids...") + encoder_grid = create_hypercomm_grid( + offset=encoder_offset, tp=encoder_tp, cp=1, pp=encoder_pp, dp=encoder_dp + ) + llm_grid = create_hypercomm_grid( + offset=llm_offset, tp=llm_tp, cp=1, pp=llm_pp, dp=llm_dp + ) + + torch.manual_seed(12345) + + logger.info(f"[Rank {dist.get_rank()}] Creating MIMO model...") + mimo_model, module_to_grid_map, topology = get_mimo_model( + encoder_name=encoder_name, + language_module_name=language_module_name, + encoder_grid=encoder_grid, + llm_grid=llm_grid, + hidden_size=hidden_size, + num_layers=num_layers, + vocab_size=vocab_size, + seq_len=seq_length, + ) + + # Add schedule-related functions to the model's existing config (TransformerConfig) + # Don't replace it with ModelParallelConfig - schedule expects TransformerConfig attributes + def no_sync_func(): + from contextlib import contextmanager, ExitStack + + @contextmanager + def combined_no_sync(): + with ExitStack() as stack: + if mimo_model.language_model is not None: + stack.enter_context(mimo_model.language_model.no_sync()) + for submodule in mimo_model.modality_submodules.values(): + if submodule is not None: + stack.enter_context(submodule.no_sync()) + yield + + return combined_no_sync() + + def finalize_grads_func(*args, **kwargs): + if mimo_model.language_model is not None: + llm_pg = get_pg_collection_with_embedding_groups(llm_grid) + finalize_model_grads([mimo_model.language_model], num_tokens=None, pg_collection=llm_pg) + for submodule in mimo_model.modality_submodules.values(): + if submodule is not None: + encoder_pg = get_pg_collection_with_embedding_groups(encoder_grid) + finalize_model_grads([submodule], num_tokens=None, pg_collection=encoder_pg) + + # Add schedule functions to existing model config + mimo_model.config.no_sync_func = no_sync_func + mimo_model.config.finalize_model_grads_func = finalize_grads_func + mimo_model.config.grad_scale_func = lambda loss: ( + torch.tensor(loss, dtype=torch.float32, device='cuda', requires_grad=True) + if isinstance(loss, (int, float)) else loss + ) + + logger.info(f"[Rank {dist.get_rank()}] Creating communicator...") + communicator = MultiModulePipelineCommunicator( + module_to_grid_map, topology, mimo_model.config, dim_mapping={'s': 0, 'h': 2, 'b': 1} + ) + + # Create data iterator on: + # - Encoder's first PP stage (needs modality_inputs) + # - LLM's first PP stage (needs input_ids for embeddings) + # - LLM's last PP stage (needs labels for loss) + data_iterator = None + + encoder_needs_data = ( + is_rank_in_grid(encoder_grid) and + is_pp_first_stage(encoder_grid.get_pg("pp")) + ) + llm_needs_data = ( + is_rank_in_grid(llm_grid) and + (is_pp_first_stage(llm_grid.get_pg("pp")) or is_pp_last_stage(llm_grid.get_pg("pp"))) + ) + + if encoder_needs_data or llm_needs_data: + logger.info(f"[Rank {dist.get_rank()}] Creating data iterator (encoder={encoder_needs_data}, llm={llm_needs_data})") + data_iterator = DataIterator(hidden_size, seq_length, micro_batch_size, vocab_size, encoder_name) + + # Build MultiModuleProcessGroupCollection + # Only include pg_collections for modules this rank participates in + module_pgs = {} + if is_rank_in_grid(encoder_grid): + module_pgs[encoder_name] = get_pg_collection_with_embedding_groups(encoder_grid) + if is_rank_in_grid(llm_grid): + module_pgs[language_module_name] = get_pg_collection_with_embedding_groups(llm_grid) + + # Set language_model_module_name only if this rank participates in LLM + lang_module_name = language_module_name if is_rank_in_grid(llm_grid) else None + + pg_collection = MultiModuleProcessGroupCollection( + module_pgs=module_pgs, + language_model_module_name=lang_module_name, + ) + + def step_func(data_iterator, model): + from functools import partial + + def loss_func(loss_mask, output_tensor): + """Loss function matching POC pattern.""" + if output_tensor is None: + return torch.tensor(0.0, device='cuda', requires_grad=True), {'loss_reduced': 0.0} + + # Handle dict output (from encoder or intermediate LLM stages) + if isinstance(output_tensor, dict): + if language_module_name in output_tensor: + output = output_tensor[language_module_name] + else: + output = list(output_tensor.values())[0] if output_tensor else None + else: + output = output_tensor + + if output is None: + return torch.tensor(0.0, device='cuda', requires_grad=True), {'loss_reduced': 0.0} + + loss = output.float().sum() + return loss, {'loss_reduced': loss} + + batch = next(data_iterator) if data_iterator is not None else {'input_ids': None} + # MimoModel.forward() returns (output_tensor, loss_mask) tuple + output_tensor, loss_mask = model(**batch) + # Return only output_tensor, bind loss_mask to loss_func via partial + return output_tensor, partial(loss_func, loss_mask) + + logger.info(f"[Rank {dist.get_rank()}] Running 1F1B schedule with {num_microbatches} microbatches...") + losses = schedule.forward_backward_pipelining_without_interleaving( + forward_step_func=step_func, + data_iterator=data_iterator, + model=[mimo_model], + num_microbatches=num_microbatches, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + forward_only=False, + p2p_communicator=communicator, + pg_collection=pg_collection, + ) + + # Verify results on last LLM stage + if is_rank_in_grid(llm_grid): + if is_pp_last_stage(llm_grid.get_pg("pp")): + logger.info(f"[Rank {dist.get_rank()}] Last LLM stage - got {len(losses)} losses") + assert len(losses) > 0, "Expected losses on last LLM stage" + for loss_dict in losses: + assert 'loss_reduced' in loss_dict, "Expected 'loss_reduced' in loss dict" + + logger.info(f"[Rank {dist.get_rank()}] Test completed successfully!") + return losses + + +def get_test_configs(): + """Get predefined test configurations for different GPU counts. + + Returns: + Dict mapping world_size to list of test configurations. + """ + return { + # 2 GPUs: Encoder PP=1, LLM PP=1 (baseline) + 2: [ + { + "name": "baseline_2gpu", + "encoder_tp": 1, "encoder_pp": 1, "encoder_dp": 1, "encoder_offset": 0, + "llm_tp": 1, "llm_pp": 1, "llm_dp": 1, "llm_offset": 1, + "hidden_size": 256, "num_layers": 2, "vocab_size": 1000, + "seq_length": 64, "micro_batch_size": 2, "num_microbatches": 4, + }, + ], + # 4 GPUs: Encoder PP=1, LLM PP=3 (tests keyed output fix) + 4: [ + { + "name": "lm_pp3_4gpu", + "encoder_tp": 1, "encoder_pp": 1, "encoder_dp": 1, "encoder_offset": 0, + "llm_tp": 1, "llm_pp": 3, "llm_dp": 1, "llm_offset": 1, + "hidden_size": 256, "num_layers": 2, "vocab_size": 1000, + "seq_length": 64, "micro_batch_size": 2, "num_microbatches": 4, + }, + ], + # 8 GPUs: Multiple configurations + 8: [ + # Config 1: Encoder TP=2 PP=1, LLM TP=2 PP=3 (heterogeneous) + # Encoder: 2 ranks (0-1), LLM: 6 ranks (2-7) + # num_layers must be divisible by pp, so use 3 + { + "name": "encoder_tp2_llm_tp2_pp3_8gpu", + "encoder_tp": 2, "encoder_pp": 1, "encoder_dp": 1, "encoder_offset": 0, + "llm_tp": 2, "llm_pp": 3, "llm_dp": 1, "llm_offset": 2, + "hidden_size": 256, "num_layers": 3, "vocab_size": 1000, + "seq_length": 64, "micro_batch_size": 2, "num_microbatches": 4, + }, + # Config 2: Encoder PP=2, LLM PP=2 with TP=2 each + # Encoder: 4 ranks (0-3), LLM: 4 ranks (4-7) + { + "name": "full_pp_8gpu", + "encoder_tp": 2, "encoder_pp": 2, "encoder_dp": 1, "encoder_offset": 0, + "llm_tp": 2, "llm_pp": 2, "llm_dp": 1, "llm_offset": 4, + "hidden_size": 256, "num_layers": 2, "vocab_size": 1000, + "seq_length": 64, "micro_batch_size": 2, "num_microbatches": 4, + }, + ], + } + + +def main(): + """Main entry point.""" + import argparse + + parser = argparse.ArgumentParser(description="MIMO 1F1B Schedule Test") + parser.add_argument("--config", type=str, default=None, + help="Specific config name to run (e.g., 'baseline_2gpu')") + parser.add_argument("--list-configs", action="store_true", + help="List available configurations and exit") + args = parser.parse_args() + + # List configs if requested + if args.list_configs: + configs = get_test_configs() + print("Available configurations:") + for world_size, config_list in configs.items(): + print(f"\n {world_size} GPUs:") + for cfg in config_list: + print(f" - {cfg['name']}") + return + + # Initialize distributed + dist.init_process_group(backend="nccl") + rank = dist.get_rank() + world_size = dist.get_world_size() + torch.cuda.set_device(rank) + + logger.info(f"Rank {rank}/{world_size} initialized") + + configs = get_test_configs() + + if world_size not in configs: + logger.error(f"No configurations for world_size={world_size}. Available: {list(configs.keys())}") + dist.destroy_process_group() + return + + # Filter configs if specific one requested + test_configs = configs[world_size] + if args.config: + test_configs = [c for c in test_configs if c["name"] == args.config] + if not test_configs: + logger.error(f"Config '{args.config}' not found for {world_size} GPUs") + dist.destroy_process_group() + return + + # Run all matching configs + for config in test_configs: + name = config.pop("name") + logger.info(f"Running test: {name}") + try: + run_mimo_1f1b_test(**config) + logger.info(f"Test {name} PASSED") + except Exception as e: + logger.error(f"Test {name} FAILED: {e}") + raise + finally: + config["name"] = name # Restore for potential reuse + + dist.destroy_process_group() + logger.info("All tests completed successfully!") + + +if __name__ == "__main__": + main() diff --git a/tests/unit_tests/models/test_mimo_audio_submodules.py b/tests/unit_tests/models/test_mimo_audio_submodules.py index 0f3865d940f..f9a18838f60 100644 --- a/tests/unit_tests/models/test_mimo_audio_submodules.py +++ b/tests/unit_tests/models/test_mimo_audio_submodules.py @@ -394,3 +394,33 @@ def test_multiple_audio_encoders(self, model_name, batch_size): print( f"Model {model_name} (d_model={self.d_model}) successfully processed audio and projected to dimension 768" ) + + +class TestAudioSubmoduleStageAware: + """Tests for stage-aware forward in AudioModalitySubmodules.""" + + def test_stage_aware_forward(self): + """Test stage-aware forward: hidden_states input and projection skipping.""" + import torch.nn as nn + + hidden_size = 64 + projection_size = 128 + hidden_states = torch.randn(10, hidden_size) + + # Non-first stage uses hidden_states, last stage projects + submodule_last = AudioModalitySubmodules( + input_projections=[nn.Linear(hidden_size, projection_size)], + is_first_stage=False, + is_last_stage=True, + ) + output = submodule_last.forward(hidden_states=hidden_states) + assert output.shape == (10, projection_size) # Projected + + # Non-last stage skips projection + submodule_mid = AudioModalitySubmodules( + input_projections=[nn.Linear(hidden_size, projection_size)], + is_first_stage=False, + is_last_stage=False, + ) + output = submodule_mid.forward(hidden_states=hidden_states) + assert output.shape == (10, hidden_size) # Not projected diff --git a/tests/unit_tests/models/test_mimo_model.py b/tests/unit_tests/models/test_mimo_model.py index f786f118c68..e664edc9dee 100644 --- a/tests/unit_tests/models/test_mimo_model.py +++ b/tests/unit_tests/models/test_mimo_model.py @@ -455,3 +455,152 @@ def test_state_dict(self): # Test checkpoint state dict checkpoint_dict = mimo_model.state_dict_for_save_checkpoint() assert len(checkpoint_dict) > 0 + + +class MockProcessGroup: + """Mock process group for testing.""" + def __init__(self, rank, world_size): + self._rank = rank + self._size = world_size + + def rank(self): + return self._rank + + def size(self): + return self._size + + +class MockGrid: + """Mock grid with HyperCommGrid-compatible interface.""" + def __init__(self, rank_offset=0, size=1, dim_names=None, pp_rank=0, pp_size=1): + self.rank_offset = rank_offset + self.size = size + self.dim_names = dim_names or [] + self._pp_group = MockProcessGroup(pp_rank, pp_size) + + def get_pg(self, dims): + if dims == "pp": + return self._pp_group + raise KeyError(f"Process group for {dims} not found") + + +class TestMimoModelNonColocated: + """Tests for non-colocated multi-module pipeline parallelism.""" + + def setup_method(self, method): + try: + Utils.initialize_model_parallel(1, 1) + except Exception: + pass + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.hidden_size = 64 + self.vocab_size = 48000 + self.seq_len = 256 + self.batch_size = 2 + + def teardown_method(self, method): + try: + Utils.destroy_model_parallel() + except Exception: + pass + + def _make_config(self, encoder_in_grid=True, language_in_grid=True, pp_rank=0, pp_size=1): + """Helper to create MimoModelConfig with mock grids.""" + language_model_spec = get_language_model_spec(self.hidden_size, self.vocab_size, self.seq_len) + vision_submodule_spec = get_vision_submodules_spec(self.hidden_size, 224, 224, 16) + + encoder_offset = 0 if encoder_in_grid else 10 # rank 0 in grid if offset=0 + language_offset = 0 if language_in_grid else 10 + + return MimoModelConfig( + language_model_spec=language_model_spec, + modality_submodules_spec={"images": vision_submodule_spec}, + special_token_ids={"images": 50257}, + module_to_grid_map={ + "images": MockGrid(rank_offset=encoder_offset, size=1, dim_names=["pp"] if pp_size > 1 else [], pp_rank=pp_rank, pp_size=pp_size), + "language": MockGrid(rank_offset=language_offset, size=1, dim_names=["pp"] if pp_size > 1 else [], pp_rank=pp_rank, pp_size=pp_size), + }, + language_module_key="language", + ) + + def test_grid_validation_rejects_mismatched_keys(self): + """Test validation fails when grid_map keys don't match expected modules.""" + language_model_spec = get_language_model_spec(self.hidden_size, self.vocab_size, self.seq_len) + vision_submodule_spec = get_vision_submodules_spec(self.hidden_size, 224, 224, 16) + + # Missing 'images' in grid_map + mimo_config = MimoModelConfig( + language_model_spec=language_model_spec, + modality_submodules_spec={"images": vision_submodule_spec}, + special_token_ids={"images": 50257}, + module_to_grid_map={"language": MockGrid()}, + language_module_key="language", + ) + + with pytest.raises(ValueError, match="module_to_grid_map keys must match"): + MimoModel(mimo_config) + + def test_role_determination(self): + """Test role correctly identifies modules and stage positions.""" + # No grid map = no role + model_no_grid = get_vlm_mimo_model(self.hidden_size, self.vocab_size, self.seq_len, 224, 224, 16, {"images": 50257}) + assert model_no_grid.role is None + + # Encoder-only rank (language grid excludes rank 0) + model_encoder = MimoModel(self._make_config(encoder_in_grid=True, language_in_grid=False)) + assert model_encoder.role.has_modality_modules is True + assert model_encoder.role.has_language_module is False + + # Language-only rank (encoder grid excludes rank 0) + model_language = MimoModel(self._make_config(encoder_in_grid=False, language_in_grid=True)) + assert model_language.role.has_modality_modules is False + assert model_language.role.has_language_module is True + + # Stage info with PP + model_pp = MimoModel(self._make_config(encoder_in_grid=True, language_in_grid=True, pp_rank=1, pp_size=3)) + assert model_pp.role.is_first_stage("images") is False + assert model_pp.role.is_last_stage("images") is False + + def test_selective_init_encoder_only(self): + """Test encoder-only rank initializes encoder but not language model.""" + model = MimoModel(self._make_config(encoder_in_grid=True, language_in_grid=False)) + assert "images" in model.modality_submodules + assert model.language_model is None + + def test_selective_init_language_only(self): + """Test language-only rank initializes language model but not encoder.""" + model = MimoModel(self._make_config(encoder_in_grid=False, language_in_grid=True)) + assert "images" not in model.modality_submodules + assert model.language_model is not None + + def test_forward_encoder_only(self): + """Test encoder-only forward returns dict of embeddings.""" + model = MimoModel(self._make_config(encoder_in_grid=True, language_in_grid=False)) + model = model.to(self.device) + + images = torch.rand(2, 3, 224, 224, device=self.device) + input_ids = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len), device=self.device) + + outputs, _ = model(input_ids=input_ids, modality_inputs={"images": {"clip_encoder": {"x": images}}}) + + assert isinstance(outputs, dict) + assert "images" in outputs + + def test_forward_language_only(self): + """Test language-only forward returns tensor.""" + model = MimoModel(self._make_config(encoder_in_grid=False, language_in_grid=True)) + model = model.to(self.device) + + img_seq_len = (224 // 16) * (224 // 16) + 1 + input_ids = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len), device=self.device) + input_ids[:, 5:5 + img_seq_len] = 50257 # image tokens + position_ids = torch.arange(self.seq_len, device=self.device).unsqueeze(0).expand(self.batch_size, -1) + + # Simulate encoder output from previous stage + encoder_embeddings = torch.randn(self.batch_size * img_seq_len, self.hidden_size, device=self.device) + model.set_input_tensor({"images": encoder_embeddings}) + + outputs, _ = model(input_ids=input_ids, position_ids=position_ids, modality_inputs=None) + + assert isinstance(outputs, torch.Tensor) + assert outputs.shape == (self.batch_size, self.seq_len, self.vocab_size) diff --git a/tests/unit_tests/models/test_mimo_role.py b/tests/unit_tests/models/test_mimo_role.py new file mode 100644 index 00000000000..28f2c5cae54 --- /dev/null +++ b/tests/unit_tests/models/test_mimo_role.py @@ -0,0 +1,54 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +"""Tests for MIMO role data classes.""" + +import pytest + +from megatron.core.models.mimo.config.role import ModuleStageInfo, RankRole + + +class TestMimoRole: + """Tests for ModuleStageInfo and RankRole dataclasses.""" + + def test_module_stage_info(self): + """Test ModuleStageInfo creation and attributes.""" + first = ModuleStageInfo(is_first_stage=True, is_last_stage=False) + last = ModuleStageInfo(is_first_stage=False, is_last_stage=True) + only = ModuleStageInfo(is_first_stage=True, is_last_stage=True) + + assert (first.is_first_stage, first.is_last_stage) == (True, False) + assert (last.is_first_stage, last.is_last_stage) == (False, True) + assert (only.is_first_stage, only.is_last_stage) == (True, True) + + def test_rank_role(self): + """Test RankRole properties and methods.""" + # Encoder-only role + encoder_role = RankRole( + modules={"vision": ModuleStageInfo(True, False)}, + language_module_name="language", + ) + assert encoder_role.has_modality_modules is True + assert encoder_role.has_language_module is False + assert encoder_role.modality_module_names == ["vision"] + + # Language-only role + lang_role = RankRole( + modules={"language": ModuleStageInfo(True, True)}, + language_module_name="language", + ) + assert lang_role.has_modality_modules is False + assert lang_role.has_language_module is True + + # Mixed role with stage checks + mixed = RankRole( + modules={ + "vision": ModuleStageInfo(is_first_stage=True, is_last_stage=False), + "language": ModuleStageInfo(is_first_stage=False, is_last_stage=True), + }, + language_module_name="language", + ) + assert mixed.is_first_stage("vision") is True + assert mixed.is_last_stage("vision") is False + assert mixed.is_first_stage("language") is False + assert mixed.is_last_stage("language") is True + assert mixed.is_first_stage("nonexistent") is False diff --git a/tests/unit_tests/models/test_mimo_submodules.py b/tests/unit_tests/models/test_mimo_submodules.py index 6111394cc13..5f8de29cc0f 100644 --- a/tests/unit_tests/models/test_mimo_submodules.py +++ b/tests/unit_tests/models/test_mimo_submodules.py @@ -303,3 +303,42 @@ def test_empty_data_batch(self): # Test forward pass output = self.vision_submodule(data_batch) assert output is None + + +@pytest.mark.experimental +class TestVisionSubmoduleStageAware: + """Tests for stage-aware forward in VisionModalitySubmodules.""" + + def test_stage_aware_forward(self): + """Test stage-aware forward: hidden_states input and projection skipping.""" + hidden_size = 64 + projection_size = 128 + hidden_states = torch.randn(10, hidden_size) + + # Default: first and last stage + submodule_default = VisionModalitySubmodules( + input_projections=[nn.Linear(hidden_size, projection_size)] + ) + assert submodule_default.is_first_stage is True + assert submodule_default.is_last_stage is True + + # Non-first stage uses hidden_states, last stage projects + submodule_last = VisionModalitySubmodules( + input_projections=[nn.Linear(hidden_size, projection_size)], + is_first_stage=False, + is_last_stage=True, + ) + output = submodule_last.forward(hidden_states=hidden_states) + assert output.shape == (10, projection_size) # Projected + + # Non-last stage skips projection + submodule_mid = VisionModalitySubmodules( + input_projections=[nn.Linear(hidden_size, projection_size)], + is_first_stage=False, + is_last_stage=False, + ) + output = submodule_mid.forward(hidden_states=hidden_states) + assert output.shape == (10, hidden_size) # Not projected + + # No input returns None + assert submodule_mid.forward(hidden_states=None) is None From 7da19e1447db9c19bbd698a14db6b098e56c5c4c Mon Sep 17 00:00:00 2001 From: ykarnati Date: Mon, 2 Feb 2026 15:01:37 -0800 Subject: [PATCH 15/30] Add .worktrees/ to gitignore Co-Authored-By: Claude Opus 4.5 --- .gitignore | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index a9ce4aa0a93..5556d1d5a4a 100644 --- a/.gitignore +++ b/.gitignore @@ -19,4 +19,7 @@ runs/ # Sphinx documentation docs/_build -docs/apidocs \ No newline at end of file +docs/apidocs + +# Git worktrees +.worktrees/ \ No newline at end of file From 62dfb8920e53aec4daef3b7e6a013a96e5832450 Mon Sep 17 00:00:00 2001 From: ykarnati Date: Thu, 19 Mar 2026 08:20:43 -0700 Subject: [PATCH 16/30] Simplify MIMO model and consolidate submodule logic - base.py: Remove redundant conditionals in _initialize_submodules, simplify forward() dispatch with guard-first pattern, collapse _forward_encoders to single-expression conditionals, return None from _determine_role when rank is in no grid - submodules/base.py: Promote encode, combine_embeddings, project_embeddings, and forward from abstract to concrete methods, fix missing f-prefix in error message, fix project_embeddings to always combine before projecting - submodules/vision.py, audio.py: Remove duplicate implementations, keep only __init__ (with projection assertions) and decode - config/role.py: Add __post_init__ validation for language_module_name - from_spec docstring: Document is_first_stage controls output projections Co-Authored-By: Claude Opus 4.6 (1M context) --- megatron/core/models/mimo/config/role.py | 8 + megatron/core/models/mimo/model/base.py | 49 +++--- megatron/core/models/mimo/submodules/audio.py | 135 +-------------- megatron/core/models/mimo/submodules/base.py | 110 +++++++++---- .../core/models/mimo/submodules/vision.py | 155 +----------------- 5 files changed, 127 insertions(+), 330 deletions(-) diff --git a/megatron/core/models/mimo/config/role.py b/megatron/core/models/mimo/config/role.py index 6a4650c45c4..677218f4c24 100644 --- a/megatron/core/models/mimo/config/role.py +++ b/megatron/core/models/mimo/config/role.py @@ -37,6 +37,14 @@ class RankRole: modules: Dict[str, ModuleStageInfo] = field(default_factory=dict) language_module_name: Optional[str] = None + def __post_init__(self): + """Validate that language_module_name is set when modules is non-empty.""" + if self.modules and self.language_module_name is None: + raise ValueError( + "language_module_name must be set when modules is non-empty. " + f"Got modules={list(self.modules.keys())} with language_module_name=None." + ) + @property def has_modality_modules(self) -> bool: """Return True if this rank participates in any modality (non-language) module.""" diff --git a/megatron/core/models/mimo/model/base.py b/megatron/core/models/mimo/model/base.py index 411dbe049cd..8df257b7c5e 100644 --- a/megatron/core/models/mimo/model/base.py +++ b/megatron/core/models/mimo/model/base.py @@ -174,7 +174,7 @@ def _initialize_submodules(self) -> None: # Determine stage info for this module is_first_stage = True is_last_stage = True - if self.role is not None and modality_name in self.role.modules: + if self.role is not None: stage_info = self.role.modules[modality_name] is_first_stage = stage_info.is_first_stage is_last_stage = stage_info.is_last_stage @@ -277,6 +277,9 @@ def _determine_role(self) -> Optional[RankRole]: ) modules[module_name] = ModuleStageInfo(is_first_stage=is_first, is_last_stage=is_last) + if not modules: + return None + return RankRole(modules=modules, language_module_name=self.mimo_config.language_module_key) def set_input_tensor(self, input_tensor): @@ -407,12 +410,19 @@ def forward( packing_kwargs, ) - if self.role.has_modality_modules and not self.role.has_language_module: - # Encoder-only rank + # Guard: colocated encoders + language module is not supported + if self.role.has_modality_modules and self.role.has_language_module: + raise ValueError( + "Invalid configuration: Colocated encoders and language module on the same " + "rank is not supported in multi-module pipeline parallelism. Use separate " + "grids for encoders and language module, or disable multi-module PP by not " + "setting module_to_grid_map." + ) + + if self.role.has_modality_modules: return self._forward_encoders(modality_inputs, input_tensors), loss_mask - if self.role.has_language_module and not self.role.has_modality_modules: - # Language-module-only rank + if self.role.has_language_module: return ( self._forward_language_module( input_ids, position_ids, attention_mask, labels, input_tensors @@ -420,15 +430,6 @@ def forward( loss_mask, ) - if self.role.has_modality_modules and self.role.has_language_module: - # Colocated encoders and language module is a configuration error - raise ValueError( - "Invalid configuration: Colocated encoders and language module on the same " - "rank is not supported in multi-module pipeline parallelism. Use separate " - "grids for encoders and language module, or disable multi-module PP by not " - "setting module_to_grid_map." - ) - raise RuntimeError(f"Rank has no modules assigned in role: {self.role}") def _forward_encoders( @@ -455,19 +456,19 @@ def _forward_encoders( # Determine input based on stage position if self.role.is_first_stage(encoder_name): - # First stage: use raw modality inputs encoder_input = modality_inputs.get(encoder_name) if modality_inputs else None - if encoder_input is not None: - output = submodule.forward(encoder_inputs=encoder_input) - else: - output = None + output = ( + submodule.forward(encoder_inputs=encoder_input) + if encoder_input is not None + else None + ) else: - # Non-first stage: use hidden states from previous stage hidden_states = input_tensors.get(encoder_name) if input_tensors else None - if hidden_states is not None: - output = submodule.forward(hidden_states=hidden_states) - else: - output = None + output = ( + submodule.forward(hidden_states=hidden_states) + if hidden_states is not None + else None + ) if output is not None: outputs[encoder_name] = output diff --git a/megatron/core/models/mimo/submodules/audio.py b/megatron/core/models/mimo/submodules/audio.py index 5b4910a2989..6db2782d82f 100644 --- a/megatron/core/models/mimo/submodules/audio.py +++ b/megatron/core/models/mimo/submodules/audio.py @@ -1,16 +1,11 @@ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -import logging -from typing import Any, Dict, List, Optional +from typing import Dict, List, Optional -import torch import torch.nn as nn from megatron.core.models.mimo.submodules.base import ModalitySubmodules -# Initialize logger -logger = logging.getLogger(__name__) - class AudioModalitySubmodules(ModalitySubmodules): """Audio modality submodules for encoding, decoding, and projecting audio data.""" @@ -32,7 +27,13 @@ def __init__( output_projections: List of output projection modules **kwargs: Additional keyword arguments """ - super().__init__(encoders, decoders, input_projections, output_projections, **kwargs) + super().__init__( + encoders=encoders, + decoders=decoders, + input_projections=input_projections, + output_projections=output_projections, + **kwargs, + ) if self.input_projections: assert ( @@ -44,124 +45,6 @@ def __init__( len(self.output_projections) <= 1 ), "AudioModalitySubmodules currently supports only one output projection" - def encode(self, encoders_data_batch: Dict) -> List[torch.Tensor]: - """Encode audio data into a sequence of embeddings. - - Args: - encoders_data_batch: Dictionary containing encoder-specific inputs. - Keys should match encoder names in self.encoders. - Each encoder receives its own specific inputs. - - Returns: - List of encoded audio embeddings, one from each encoder. - Each embedding is a flattened tensor of shape [total_tokens, hidden_dim] - - Raises: - ValueError: If no data is provided for any encoder or if there's a parameter mismatch. - """ - if not encoders_data_batch: - return [] - - embeddings = [] - - for name, encoder in self.encoders.items(): - if name not in encoders_data_batch: - raise ValueError(f"No inputs found for encoder '{name}'") - - encoder_inputs = encoders_data_batch[name] - - # Process inputs through the encoder - encoder_outputs = encoder(**encoder_inputs) - logger.debug(f"Encoder '{name}' output shape: {encoder_outputs.shape}") - if encoder_outputs.ndim == 3: - # its b,s,h -> we need to flatten it to b*s,h - encoder_outputs = encoder_outputs.reshape(-1, encoder_outputs.size(-1)) - embeddings.append(encoder_outputs) - elif encoder_outputs.ndim == 2: - # its b*s,h -> encoder already returned the flattened output - embeddings.append(encoder_outputs) - else: - raise ValueError( - f"Encoder '{name}' output shape {encoder_outputs.shape} is not supported" - "Expected 3D (b,s,h) or 2D (b*s,h) tensor, got {encoder_outputs.ndim}D" - ) - return embeddings - - def decode(self, embeddings: torch.Tensor, data_batch: Dict) -> torch.Tensor: + def decode(self, embeddings, data_batch: Dict): """Decode embeddings into audio data.""" raise NotImplementedError("Audio decoding not implemented yet") - - def combine_embeddings(self, embeddings: List[torch.Tensor]) -> torch.Tensor: - """Combine embeddings from different encoders.""" - if not embeddings: - raise ValueError("Cannot combine empty list of embeddings") - - if len(embeddings) == 1: - return embeddings[0] - - # Concatenate along sequence dimension - # each embedding is [total_tokens, hidden_dim] - combined = torch.cat(embeddings, dim=0) - logger.debug(f"Combined audio embeddings shape: {combined.shape}") - return combined - - def project_embeddings( - self, embeddings: List[torch.Tensor], is_input: bool = True - ) -> torch.Tensor: - """Project embeddings to the language model dimension space.""" - - if is_input: - embeddings = self.combine_embeddings(embeddings) - - # Get the appropriate projections - projections = self.input_projections if is_input else self.output_projections - - # Apply projection if available - if projections: - # We've asserted in __init__ that there's only one projection - projection = projections[0] - projected = projection(embeddings) - logger.debug(f"Post-projection audio embeddings shape: {projected.shape}") - return projected - - return embeddings - - def forward( - self, - encoder_inputs: Optional[Dict[str, Any]] = None, - hidden_states: Optional[torch.Tensor] = None, - ) -> Optional[torch.Tensor]: - """Process audio data through encoding and projection. - - Args: - encoder_inputs: Dictionary where keys match encoder names in self.encoders - and values are dictionaries of encoder-specific parameters. - Used when is_first_stage=True. - hidden_states: Hidden states from previous pipeline stage. - Used when is_first_stage=False. - - Returns: - - If is_last_stage: projected embeddings ready for language model - - If not is_last_stage: hidden states for next pipeline stage - - None if no valid input provided - """ - # Determine input based on stage position - if self.is_first_stage: - if encoder_inputs is None: - return None - # Encode the audio - embeddings = self.encode(encoder_inputs) - if not embeddings: - return None - combined = self.combine_embeddings(embeddings) - else: - if hidden_states is None: - return None - # Use hidden states from previous stage - combined = hidden_states - - # Project only if last stage - if self.is_last_stage: - return self.project_embeddings([combined], is_input=True) - else: - return combined diff --git a/megatron/core/models/mimo/submodules/base.py b/megatron/core/models/mimo/submodules/base.py index a52c9cf1cc6..58f61f81d3c 100644 --- a/megatron/core/models/mimo/submodules/base.py +++ b/megatron/core/models/mimo/submodules/base.py @@ -85,19 +85,17 @@ def is_last_stage(self) -> bool: @classmethod def from_spec( - cls, - module_spec: ModuleSpec, - is_first_stage: bool = True, - is_last_stage: bool = True, + cls, module_spec: ModuleSpec, is_first_stage: bool = True, is_last_stage: bool = True ) -> 'ModalitySubmodules': """Create a modality submodule from ModuleSpec configuration. Args: module_spec (ModuleSpec): The module specification for this modality submodule is_first_stage (bool): Whether this is the first pipeline stage for this module. - Controls encoder initialization. Defaults to True. + Controls encoder initialization and output projection initialization + (output projections only built on first stage). Defaults to True. is_last_stage (bool): Whether this is the last pipeline stage for this module. - Controls input projection initialization (only needed on last stage). + Controls input projection initialization (only built on last stage). Defaults to True. Returns: @@ -137,9 +135,7 @@ def from_spec( projection = build_module(proj_spec) input_projections.append(projection) elif 'input_projections' in submodules: - logger.debug( - f"Skipping {cls.__name__} input projections (not last stage)" - ) + logger.debug(f"Skipping {cls.__name__} input projections (not last stage)") # Build output projections only on first stage # (projection happens before decoding, after receiving from language model) @@ -152,9 +148,7 @@ def from_spec( projection = build_module(proj_spec) output_projections.append(projection) elif 'output_projections' in submodules: - logger.debug( - f"Skipping {cls.__name__} output projections (not first stage)" - ) + logger.debug(f"Skipping {cls.__name__} output projections (not first stage)") # Pass any additional parameters from the params dictionary additional_params = params.copy() @@ -173,31 +167,61 @@ def from_spec( **additional_params, ) - @abstractmethod def combine_embeddings(self, embeddings: List[torch.Tensor]) -> torch.Tensor: - """Combine multiple embeddings from different encoders. + """Combine multiple embeddings from different encoders by concatenation. Args: embeddings (List[torch.Tensor]): - List of embeddings to combine + List of embeddings to combine. Each is [total_tokens, hidden_dim]. Returns: torch.Tensor: Combined embedding tensor """ - pass + if not embeddings: + raise ValueError("Cannot combine empty list of embeddings") - @abstractmethod - def encode(self, data_batch: Dict) -> List[torch.Tensor]: + if len(embeddings) == 1: + return embeddings[0] + + combined = torch.cat(embeddings, dim=0) + logger.debug(f"Combined embeddings shape after concatenation: {combined.shape}") + return combined + + def encode(self, encoders_data_batch: Dict) -> List[torch.Tensor]: """Encode data batch into a list of tensors. Args: - data_batch (Dict): - Dictionary containing input data + encoders_data_batch (Dict): + Dictionary containing encoder-specific inputs. + Keys should match encoder names in self.encoders. Returns: - List[torch.Tensor]: List of encoded embeddings + List[torch.Tensor]: List of encoded embeddings, each [total_tokens, hidden_dim] """ - pass + if not encoders_data_batch: + return [] + + embeddings = [] + + for name, encoder in self.encoders.items(): + if name not in encoders_data_batch: + raise ValueError(f"No inputs found for encoder '{name}'") + + encoder_inputs = encoders_data_batch[name] + encoder_outputs = encoder(**encoder_inputs) + logger.debug(f"Encoder '{name}' output shape: {encoder_outputs.shape}") + + if encoder_outputs.ndim == 3: + encoder_outputs = encoder_outputs.reshape(-1, encoder_outputs.size(-1)) + elif encoder_outputs.ndim != 2: + raise ValueError( + f"Encoder '{name}' output shape {encoder_outputs.shape} is not supported. " + f"Expected 3D (b,s,h) or 2D (b*s,h) tensor, got {encoder_outputs.ndim}D" + ) + + embeddings.append(encoder_outputs) + + return embeddings @abstractmethod def decode(self, embeddings: torch.Tensor, data_batch: Dict) -> torch.Tensor: @@ -214,11 +238,10 @@ def decode(self, embeddings: torch.Tensor, data_batch: Dict) -> torch.Tensor: """ pass - @abstractmethod def project_embeddings( self, embeddings: List[torch.Tensor], is_input: bool = True ) -> Optional[torch.Tensor]: - """Project embeddings into a tensor. + """Project embeddings using input or output projections. Args: embeddings (List[torch.Tensor]): @@ -229,18 +252,49 @@ def project_embeddings( Returns: Optional[torch.Tensor]: Projected embeddings or None """ - pass + combined = self.combine_embeddings(embeddings) - @abstractmethod - def forward(self, encoder_inputs: Dict[str, Any]) -> Optional[torch.Tensor]: + projections = self.input_projections if is_input else self.output_projections + + if projections: + projection = projections[0] + projected = projection(combined) + logger.debug(f"Post-projection embeddings shape: {projected.shape}") + return projected + + return combined + + def forward( + self, + encoder_inputs: Optional[Dict[str, Any]] = None, + hidden_states: Optional[torch.Tensor] = None, + ) -> Optional[torch.Tensor]: """Process data for this modality through encoding and projection. Args: encoder_inputs (Dict[str, Any]): Dictionary containing encoder-specific inputs. Keys should match encoder names. + Used when is_first_stage=True. + hidden_states (Optional[torch.Tensor]): + Hidden states from previous pipeline stage. Used when is_first_stage=False. Returns: Optional[torch.Tensor]: Processed and projected embeddings tensor, or None if no embeddings were produced. """ - pass + if self.is_first_stage: + if encoder_inputs is None: + return None + embeddings = self.encode(encoder_inputs) + if not embeddings: + return None + combined = self.combine_embeddings(embeddings) + else: + if hidden_states is None: + return None + combined = hidden_states + + if self.is_last_stage: + return self.project_embeddings([combined], is_input=True) + + return combined diff --git a/megatron/core/models/mimo/submodules/vision.py b/megatron/core/models/mimo/submodules/vision.py index dfd22e2b392..0bb1a45e013 100644 --- a/megatron/core/models/mimo/submodules/vision.py +++ b/megatron/core/models/mimo/submodules/vision.py @@ -1,16 +1,11 @@ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -import logging -from typing import Any, Dict, List, Optional +from typing import Dict, List, Optional -import torch import torch.nn as nn from megatron.core.models.mimo.submodules.base import ModalitySubmodules -# Initialize logger -logger = logging.getLogger(__name__) - class VisionModalitySubmodules(ModalitySubmodules): """Vision modality submodules for encoding, decoding, and projecting image data. @@ -53,150 +48,6 @@ def __init__( len(self.output_projections) <= 1 ), "VisionModalitySubmodules currently supports only one output projection" - def encode(self, encoders_data_batch: Dict) -> List[torch.Tensor]: - """Encode image data batch into a list of tensors. - - Args: - encoders_data_batch: Dictionary containing encoder-specific inputs. - Keys should match encoder names in self.encoders. - Each encoder receives its own specific inputs. - - Returns: - List of encoded image embeddings, one from each encoder. - Each embedding is a flattened tensor of shape [total_tokens, hidden_dim] - - Raises: - ValueError: If no data is provided for any encoder or if there's a parameter mismatch. - """ - if not encoders_data_batch: - return [] - - embeddings = [] - - for name, encoder in self.encoders.items(): - if name not in encoders_data_batch: - raise ValueError(f"No inputs found for encoder '{name}'") - - encoder_inputs = encoders_data_batch[name] - - # Process inputs through the encoder - encoder_outputs = encoder(**encoder_inputs) - logger.debug(f"Encoder '{name}' output shape: {encoder_outputs.shape}") - if encoder_outputs.ndim == 3: - # its b,s,h -> we need to flatten it to b*s,h - encoder_outputs = encoder_outputs.reshape(-1, encoder_outputs.size(-1)) - embeddings.append(encoder_outputs) - elif encoder_outputs.ndim == 2: - # its b*s,h -> encoder already returned the flattened output - embeddings.append(encoder_outputs) - else: - raise ValueError( - f"Encoder '{name}' output shape {encoder_outputs.shape} is not supported" - "Expected 3D (b,s,h) or 2D (b*s,h) tensor, got {encoder_outputs.ndim}D" - ) - - return embeddings - - def decode(self, embeddings: torch.Tensor, data_batch: Dict) -> torch.Tensor: - """Decode embeddings into image tensors. - - Args: - embeddings: Tensor of embeddings to decode. - data_batch: Dictionary containing additional data for decoding. - - Returns: - Tensor containing generated images. - """ - + def decode(self, embeddings, data_batch: Dict): + """Decode embeddings into image tensors.""" raise NotImplementedError("No decoders support yet") - - def combine_embeddings(self, embeddings: List[torch.Tensor]) -> torch.Tensor: - """Combine multiple embeddings from different encoders by concatenation. - - This method is used for combining encoder outputs before input projection. - - Args: - embeddings: List of embeddings to combine - - Returns: - Combined embedding tensor - """ - if not embeddings: - raise ValueError("Cannot combine empty list of embeddings") - - if len(embeddings) == 1: - return embeddings[0] - - # each embedding is [total_tokens, hidden_dim] - # Make this configurable in the future - combined = torch.cat(embeddings, dim=0) - logger.debug(f"Combined embeddings shape after concatenation: {combined.shape}") - return combined - - def project_embeddings( - self, embeddings: List[torch.Tensor], is_input: bool = True - ) -> torch.Tensor: - """Project image embeddings using input or output projections. - - Args: - embeddings: List of image embeddings to project - is_input: If True, use input projections, otherwise use output projections - - Returns: - Projected image embeddings or None if no embeddings - """ - if is_input: - embeddings = self.combine_embeddings(embeddings) - - # Get the appropriate projection (input or output) - projections = self.input_projections if is_input else self.output_projections - - # Apply projection if available - if projections: - # We've asserted in __init__ that there's only one projection - projection = projections[0] - projected = projection(embeddings) - logger.debug(f"Post-projection embeddings shape: {projected.shape}") - return projected - - return embeddings - - def forward( - self, - encoder_inputs: Optional[Dict[str, Any]] = None, - hidden_states: Optional[torch.Tensor] = None, - ) -> Optional[torch.Tensor]: - """Process image data through encoding and projection. - - Args: - encoder_inputs: Dictionary where keys match encoder names in self.encoders - and values are dictionaries of encoder-specific parameters. - Used when is_first_stage=True. - hidden_states: Hidden states from previous pipeline stage. - Used when is_first_stage=False. - - Returns: - - If is_last_stage: projected embeddings ready for language model - - If not is_last_stage: hidden states for next pipeline stage - - None if no valid input provided - """ - # Determine input based on stage position - if self.is_first_stage: - if encoder_inputs is None: - return None - # Encode the images - embeddings = self.encode(encoder_inputs) - if not embeddings: - return None - combined = self.combine_embeddings(embeddings) - else: - if hidden_states is None: - return None - # Use hidden states from previous stage - combined = hidden_states - - # Project only if last stage - if self.is_last_stage: - return self.project_embeddings([combined], is_input=True) - else: - return combined From 99800d2882b8777b56333431606179d5d3c720a7 Mon Sep 17 00:00:00 2001 From: ykarnati Date: Thu, 19 Mar 2026 09:55:39 -0700 Subject: [PATCH 17/30] Simplify and deduplicate MIMO model tests - Add _make_vlm/_make_avlm/_make_input_ids/_make_position_ids helpers to eliminate repeated 7-arg factory calls and tensor construction - Move device to setup_method, remove 7 duplicate torch.device() lines - Delete dead module-level AudioEncoderWrapper (duplicate of inner class) - Simplify test_state_dict to any() one-liners - Remove redundant assert-not-None before shape checks - Fix hardcoded batch_size=2 to use self.batch_size - Remove test-internal setup assertions that can never fail - Add img_h/img_w/patch_dim attrs to TestMimoModelNonColocated Co-Authored-By: Claude Opus 4.6 (1M context) --- tests/unit_tests/models/test_mimo_model.py | 444 ++++++--------------- 1 file changed, 127 insertions(+), 317 deletions(-) diff --git a/tests/unit_tests/models/test_mimo_model.py b/tests/unit_tests/models/test_mimo_model.py index 5fc086f05a0..e29dad10edf 100644 --- a/tests/unit_tests/models/test_mimo_model.py +++ b/tests/unit_tests/models/test_mimo_model.py @@ -1,7 +1,7 @@ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. ''' -WORLD_SIZE=1 LOCAL_RANK=0 python -m pytest tests/unit_tests/models/test_mimo_model.py +WORLD_SIZE=1 LOCAL_RANK=0 python -m pytest tests/unit_tests/models/test_mimo_model.py ''' import math @@ -31,13 +31,11 @@ class AudioEncoderWrapper(torch.nn.Module): """Generic wrapper for audio encoder models that extracts last_hidden_state.""" - def __init__(self, config): + def __init__(self, **kwargs): super().__init__() - # Use a local Whisper model (tiny config) to avoid checkpoint download self.encoder = WhisperModel(WhisperConfig()).encoder def forward(self, input_features): - # Process through encoder and extract last_hidden_state with torch.no_grad(): return self.encoder(input_features).last_hidden_state @@ -60,7 +58,6 @@ def get_vision_submodules_spec(hidden_size, img_h, img_w, patch_dim): }, ) - # Create vision projection spec vision_projection_spec = ModuleSpec( module=nn.Linear, params={ @@ -69,8 +66,7 @@ def get_vision_submodules_spec(hidden_size, img_h, img_w, patch_dim): }, ) - # Create vision modality spec - vision_submodule_spec = ModuleSpec( + return ModuleSpec( module=VisionModalitySubmodules, submodules={ "encoders": {"clip_encoder": vision_encoder_spec}, @@ -78,36 +74,17 @@ def get_vision_submodules_spec(hidden_size, img_h, img_w, patch_dim): }, ) - return vision_submodule_spec - def get_audio_submodules_spec(hidden_size): """Get the submodule spec for the audio modality.""" - - class AudioEncoderWrapper(torch.nn.Module): - """Generic wrapper for audio encoder models that extracts last_hidden_state.""" - - def __init__(self, model_name="openai/whisper-tiny"): - super().__init__() - # Local tiny Whisper model with random weights - self.encoder = WhisperModel(WhisperConfig()).encoder - - def forward(self, input_features): - # Process through encoder and extract last_hidden_state - with torch.no_grad(): - return self.encoder(input_features).last_hidden_state - - # Audio modality configuration - audio_encoder_spec = ModuleSpec( - module=AudioEncoderWrapper, params={"model_name": "openai/whisper-tiny"} - ) + audio_encoder_spec = ModuleSpec(module=AudioEncoderWrapper, params={}) audio_projection_spec = ModuleSpec( module=nn.Linear, params={"in_features": 384, "out_features": hidden_size}, # Whisper tiny hidden size ) - audio_submodule_spec = ModuleSpec( + return ModuleSpec( module=AudioModalitySubmodules, submodules={ "encoders": {"whisper_encoder": audio_encoder_spec}, @@ -115,8 +92,6 @@ def forward(self, input_features): }, ) - return audio_submodule_spec - def get_language_model_spec(hidden_size, vocab_size, seq_len): """Get the language model spec.""" @@ -124,7 +99,7 @@ def get_language_model_spec(hidden_size, vocab_size, seq_len): num_layers=2, hidden_size=hidden_size, num_attention_heads=4, use_cpu_initialization=True ) language_layer_spec = get_gpt_layer_with_transformer_engine_spec() - language_model_spec = ModuleSpec( + return ModuleSpec( module=GPTModel, params={ "config": lm_config, @@ -135,55 +110,44 @@ def get_language_model_spec(hidden_size, vocab_size, seq_len): "post_process": True, }, ) - return language_model_spec def get_avlm_mimo_model( hidden_size, vocab_size, seq_len, img_h, img_w, patch_dim, special_token_ids ): - language_model_spec = get_language_model_spec(hidden_size, vocab_size, seq_len) - vision_submodule_spec = get_vision_submodules_spec(hidden_size, img_h, img_w, patch_dim) - audio_submodule_spec = get_audio_submodules_spec(hidden_size) - mimo_config = MimoModelConfig( - language_model_spec=language_model_spec, - modality_submodules_spec={"images": vision_submodule_spec, "audio": audio_submodule_spec}, + language_model_spec=get_language_model_spec(hidden_size, vocab_size, seq_len), + modality_submodules_spec={ + "images": get_vision_submodules_spec(hidden_size, img_h, img_w, patch_dim), + "audio": get_audio_submodules_spec(hidden_size), + }, special_token_ids=special_token_ids, ) - - # Create MIMO model - mimo_model = MimoModel(mimo_config) - return mimo_model + return MimoModel(mimo_config) def get_vlm_mimo_model( hidden_size, vocab_size, seq_len, img_h, img_w, patch_dim, special_token_ids ): - language_model_spec = get_language_model_spec(hidden_size, vocab_size, seq_len) - vision_submodule_spec = get_vision_submodules_spec(hidden_size, img_h, img_w, patch_dim) - mimo_config = MimoModelConfig( - language_model_spec=language_model_spec, - modality_submodules_spec={"images": vision_submodule_spec}, + language_model_spec=get_language_model_spec(hidden_size, vocab_size, seq_len), + modality_submodules_spec={ + "images": get_vision_submodules_spec(hidden_size, img_h, img_w, patch_dim) + }, special_token_ids=special_token_ids, ) - - # Create MIMO model - mimo_model = MimoModel(mimo_config) - return mimo_model + return MimoModel(mimo_config) class TestMimoModel: """Test the MimoModel class.""" def setup_method(self, method): - '''setup env and model''' try: Utils.initialize_model_parallel(1, 1) - except Exception as e: - print(f"Warning: Could not initialize model parallel: {e}") + except Exception: + pass - # Set dimensions self.hidden_size = 64 self.batch_size = 2 self.seq_len = 2048 @@ -191,21 +155,28 @@ def setup_method(self, method): self.img_w = 224 self.patch_dim = 16 self.vocab_size = 48000 - - # Define special token IDs, not in LLM vocab self.special_token_ids = {"images": 50257, "audio": 50258} + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def teardown_method(self, method): - '''teardown env''' try: Utils.destroy_model_parallel() - except Exception as e: - print(f"Warning: Could not destroy model parallel: {e}") + except Exception: + pass - def test_constructor(self): - """Test constructor initialization.""" + def _make_vlm(self): + return get_vlm_mimo_model( + self.hidden_size, + self.vocab_size, + self.seq_len, + self.img_h, + self.img_w, + self.patch_dim, + self.special_token_ids, + ).to(self.device) - mimo_model = get_avlm_mimo_model( + def _make_avlm(self): + return get_avlm_mimo_model( self.hidden_size, self.vocab_size, self.seq_len, @@ -213,247 +184,124 @@ def test_constructor(self): self.img_w, self.patch_dim, self.special_token_ids, + ).to(self.device) + + def _make_input_ids(self): + return torch.randint( + 0, self.vocab_size, (self.batch_size, self.seq_len), device=self.device ) - # Move to device - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - mimo_model = mimo_model.to(device) + def _make_position_ids(self): + return ( + torch.arange(self.seq_len, device=self.device).unsqueeze(0).expand(self.batch_size, -1) + ) + + def test_constructor(self): + """Test constructor initialization.""" + mimo_model = self._make_avlm() - # Test that modality submodules were initialized correctly assert "images" in mimo_model.modality_submodules assert "audio" in mimo_model.modality_submodules assert isinstance(mimo_model.modality_submodules["images"], VisionModalitySubmodules) assert isinstance(mimo_model.modality_submodules["audio"], AudioModalitySubmodules) - # Test that language model was initialized - assert hasattr(mimo_model, "language_model") assert isinstance(mimo_model.language_model, GPTModel) - - # Test that special token IDs were set correctly assert mimo_model.special_token_ids == self.special_token_ids def test_get_text_embeddings(self): """Test getting text embeddings.""" - # Create random input and position IDs (within vocab size range) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - input_ids = torch.randint( - 0, self.vocab_size, (self.batch_size, self.seq_len), device=device - ) - position_ids = ( - torch.arange(self.seq_len, device=device).unsqueeze(0).expand(self.batch_size, -1) - ) - mimo_model = get_avlm_mimo_model( - self.hidden_size, - self.vocab_size, - self.seq_len, - self.img_h, - self.img_w, - self.patch_dim, - self.special_token_ids, - ) - mimo_model = mimo_model.to(device) - # Get text embeddings + mimo_model = self._make_avlm() + input_ids = self._make_input_ids() + position_ids = self._make_position_ids() + text_embeddings = mimo_model.get_text_embeddings( input_ids, position_ids, self.special_token_ids ) - # Verify shape - # [b*s, h] assert text_embeddings.shape == (self.batch_size * self.seq_len, self.hidden_size) def test_forward_text_only(self): """Test forward pass with only text input.""" - # Create inputs - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - input_ids = torch.randint( - 0, self.vocab_size, (self.batch_size, self.seq_len), device=device - ) - position_ids = ( - torch.arange(self.seq_len, device=device).unsqueeze(0).expand(self.batch_size, -1) - ) + mimo_model = self._make_vlm() + input_ids = self._make_input_ids() + position_ids = self._make_position_ids() - mimo_model = get_vlm_mimo_model( - self.hidden_size, - self.vocab_size, - self.seq_len, - self.img_h, - self.img_w, - self.patch_dim, - self.special_token_ids, - ) - mimo_model = mimo_model.to(device) - # Run forward pass with explicit parameters outputs, _ = mimo_model( input_ids=input_ids, position_ids=position_ids, modality_inputs=None ) - assert outputs is not None - - # Verify output shape assert outputs.shape == (self.batch_size, self.seq_len, self.vocab_size) def test_forward_with_image_modality(self): """Test forward pass with text and image input.""" - # Calculate expected number of image tokens based on image size and patch dimension - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") expected_img_seq_len = (self.img_h // self.patch_dim) * ( self.img_w // self.patch_dim ) + 1 # +1 for CLS token - # Create a fixed distribution of images: 3 in first sample, 2 in second sample num_images = 5 - images_per_sample = [3, 2] # Must sum to num_images - assert sum(images_per_sample) == num_images - assert len(images_per_sample) == self.batch_size - - # Create images tensor - images = torch.rand( - num_images, 3, self.img_h, self.img_w, device=device - ) # [num_images, 3, h, w] format + images_per_sample = [3, 2] + images = torch.rand(num_images, 3, self.img_h, self.img_w, device=self.device) + input_ids = self._make_input_ids() + position_ids = self._make_position_ids() - # Create input_ids with text tokens - input_ids = torch.randint( - 0, self.vocab_size, (self.batch_size, self.seq_len), device=device - ) - - # Create position_ids - position_ids = ( - torch.arange(self.seq_len, device=device).unsqueeze(0).expand(self.batch_size, -1) - ) - - # Include image special tokens in input IDs + # Place image special tokens in each batch sample image_token_id = self.special_token_ids["images"] - start_pos = 5 # Start position for image tokens - - # Make sure there's enough space in the sequence for all image tokens in each sample - for b in range(self.batch_size): - tokens_needed = images_per_sample[b] * expected_img_seq_len - assert ( - start_pos + tokens_needed <= self.seq_len - ), f"Sequence length too short for image tokens in sample {b}" - - # Add image tokens to each batch sample according to its number of images + start_pos = 5 for b in range(self.batch_size): tokens_in_this_batch = images_per_sample[b] * expected_img_seq_len - if tokens_in_this_batch > 0: - input_ids[b, start_pos : start_pos + tokens_in_this_batch] = image_token_id + input_ids[b, start_pos : start_pos + tokens_in_this_batch] = image_token_id - # Create modality inputs using the new structure modality_inputs = {"images": {"clip_encoder": {"x": images}}} - mimo_model = get_vlm_mimo_model( - self.hidden_size, - self.vocab_size, - self.seq_len, - self.img_h, - self.img_w, - self.patch_dim, - self.special_token_ids, - ) - mimo_model = mimo_model.to(device) - - # Run forward pass with new interface + mimo_model = self._make_vlm() outputs, _ = mimo_model( input_ids=input_ids, position_ids=position_ids, modality_inputs=modality_inputs ) - assert outputs is not None - - # Verify output shape assert outputs.shape == (self.batch_size, self.seq_len, self.vocab_size) def test_forward_with_image_and_audio_modality(self): """Test forward pass with text, image, and audio input.""" - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + mimo_model = self._make_avlm() - mimo_model = get_avlm_mimo_model( - self.hidden_size, - self.vocab_size, - self.seq_len, - self.img_h, - self.img_w, - self.patch_dim, - self.special_token_ids, - ) - mimo_model = mimo_model.to(device) - - # Calculate image sequence length img_seq_len = (self.img_h // self.patch_dim) * (self.img_w // self.patch_dim) + 1 - encoder_down_sampling = 2 - - # Create simple audio input (30 sec) - mel_bins = 80 # Whisper uses 80 mel bins + mel_bins = 80 time_bins = 3000 # 30 seconds of audio at 10ms per frame - audio_features = torch.rand(2, mel_bins, time_bins, device=device) + audio_seq_len = math.ceil(time_bins / encoder_down_sampling) - # Calculate audio sequence length using Whisper's formula - audio_seq_len = math.ceil(time_bins / encoder_down_sampling) # 1500 tokens + input_ids = self._make_input_ids() + position_ids = self._make_position_ids() - # Create batch data - batch_size = 2 - seq_len = self.seq_len - - # Create input_ids with special tokens - input_ids = torch.randint(0, self.vocab_size, (batch_size, seq_len), device=device) - position_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1) - - # Add special tokens at specific positions + # Place image and audio special tokens start_pos = 5 image_token_id = self.special_token_ids["images"] audio_token_id = self.special_token_ids["audio"] - - # Place image tokens followed by audio tokens in each batch item - for i in range(batch_size): - # Add image tokens + for i in range(self.batch_size): input_ids[i, start_pos : start_pos + img_seq_len] = image_token_id - # Add audio tokens after a gap - input_ids[ - i, start_pos + img_seq_len + 10 : start_pos + img_seq_len + 10 + audio_seq_len - ] = audio_token_id + audio_start = start_pos + img_seq_len + 10 + input_ids[i, audio_start : audio_start + audio_seq_len] = audio_token_id - # Prepare modality inputs modality_inputs = { "images": { - "clip_encoder": {"x": torch.rand(2, 3, self.img_h, self.img_w, device=device)} + "clip_encoder": {"x": torch.rand(2, 3, self.img_h, self.img_w, device=self.device)} + }, + "audio": { + "whisper_encoder": { + "input_features": torch.rand(2, mel_bins, time_bins, device=self.device) + } }, - "audio": {"whisper_encoder": {"input_features": audio_features}}, } - # Run forward pass outputs, _ = mimo_model( input_ids=input_ids, position_ids=position_ids, modality_inputs=modality_inputs ) - - # Verify output shape - assert outputs is not None - assert outputs.shape == (batch_size, seq_len, self.vocab_size) + assert outputs.shape == (self.batch_size, self.seq_len, self.vocab_size) def test_state_dict(self): """Test state dict methods.""" - # Get state dict - mimo_model = get_avlm_mimo_model( - self.hidden_size, - self.vocab_size, - self.seq_len, - self.img_h, - self.img_w, - self.patch_dim, - self.special_token_ids, - ) + mimo_model = self._make_avlm() state_dict = mimo_model.state_dict() assert len(state_dict) > 0 + assert any(k.startswith("language_model.") for k in state_dict) + assert any(k.startswith("modality_submodules.") for k in state_dict) - # Make sure we have keys for language model and modality submodules - has_lm_keys = False - has_modality_keys = False - - for key in state_dict.keys(): - if key.startswith("language_model."): - has_lm_keys = True - if key.startswith("modality_submodules."): - has_modality_keys = True - - assert has_lm_keys - assert has_modality_keys - - # Test checkpoint state dict checkpoint_dict = mimo_model.state_dict_for_save_checkpoint() assert len(checkpoint_dict) > 0 @@ -467,12 +315,11 @@ def test_pipeline_model_parallel_assertion(self): pipeline_model_parallel_size=2, pipeline_dtype=torch.float32, ) - language_layer_spec = get_gpt_layer_with_transformer_engine_spec() language_model_spec_pp2 = ModuleSpec( module=GPTModel, params={ "config": lm_config_pp2, - "transformer_layer_spec": language_layer_spec, + "transformer_layer_spec": get_gpt_layer_with_transformer_engine_spec(), "vocab_size": self.vocab_size, "max_sequence_length": self.seq_len, "pre_process": True, @@ -490,59 +337,32 @@ def test_pipeline_model_parallel_assertion(self): def test_partition_adapter_none_by_default(self): """Test that partition_adapter is None with default config (no CP/SP).""" - mimo_model = get_vlm_mimo_model( - self.hidden_size, - self.vocab_size, - self.seq_len, - self.img_h, - self.img_w, - self.patch_dim, - self.special_token_ids, - ) - # TransformerConfig defaults: context_parallel_size=1, sequence_parallel=False + mimo_model = self._make_vlm() assert mimo_model.partition_adapter is None def test_forward_with_packing_kwargs(self): """Test that packing_kwargs builds PackedSeqParams with qkv_format='thd' and int32 seqlens.""" from megatron.core.packed_seq_params import PackedSeqParams - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - input_ids = torch.randint( - 0, self.vocab_size, (self.batch_size, self.seq_len), device=device - ) - position_ids = ( - torch.arange(self.seq_len, device=device).unsqueeze(0).expand(self.batch_size, -1) - ) + mimo_model = self._make_vlm() + input_ids = self._make_input_ids() + position_ids = self._make_position_ids() - mimo_model = get_vlm_mimo_model( - self.hidden_size, - self.vocab_size, - self.seq_len, - self.img_h, - self.img_w, - self.patch_dim, - self.special_token_ids, - ) - mimo_model = mimo_model.to(device) - - # cu_seqlens covering full batch: [0, seq_len, 2*seq_len] cu_seqlens = torch.tensor( - [0, self.seq_len, 2 * self.seq_len], dtype=torch.int64, device=device + [0, self.seq_len, 2 * self.seq_len], dtype=torch.int64, device=self.device ) packing_kwargs = {"cu_seqlens_q": cu_seqlens.clone(), "cu_seqlens_kv": cu_seqlens.clone()} - # Mock get_text_embeddings and align_embeddings_by_token_positions to avoid full forward - text_emb = torch.zeros(self.batch_size * self.seq_len, self.hidden_size, device=device) - combined_emb = torch.zeros(self.seq_len, self.batch_size, self.hidden_size, device=device) + text_emb = torch.zeros(self.batch_size * self.seq_len, self.hidden_size, device=self.device) + combined_emb = torch.zeros( + self.seq_len, self.batch_size, self.hidden_size, device=self.device + ) - # Capture packed_seq_params via a side_effect on language_model.forward. - # Direct assignment (mimo_model.language_model = MagicMock()) is rejected by - # PyTorch because language_model is a registered nn.Module child. captured = {} def capture_lm_forward(*args, **kwargs): captured['packed_seq_params'] = kwargs.get('packed_seq_params') - return torch.zeros(self.batch_size, self.seq_len, self.vocab_size, device=device) + return torch.zeros(self.batch_size, self.seq_len, self.vocab_size, device=self.device) with ( patch.object(mimo_model, 'get_text_embeddings', return_value=text_emb), @@ -558,10 +378,7 @@ def capture_lm_forward(*args, **kwargs): packing_kwargs=packing_kwargs, ) - # Verify language model received a properly constructed PackedSeqParams packed_seq_params = captured['packed_seq_params'] - - assert packed_seq_params is not None assert isinstance(packed_seq_params, PackedSeqParams) assert packed_seq_params.qkv_format == 'thd' assert packed_seq_params.cu_seqlens_q.dtype == torch.int32 @@ -569,41 +386,30 @@ def capture_lm_forward(*args, **kwargs): def test_forward_with_partition_adapter(self): """Test that partition_adapter.shard() is called and embeddings are transposed correctly.""" - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - input_ids = torch.randint( - 0, self.vocab_size, (self.batch_size, self.seq_len), device=device - ) - position_ids = ( - torch.arange(self.seq_len, device=device).unsqueeze(0).expand(self.batch_size, -1) - ) + mimo_model = self._make_vlm() + input_ids = self._make_input_ids() + position_ids = self._make_position_ids() - mimo_model = get_vlm_mimo_model( - self.hidden_size, - self.vocab_size, - self.seq_len, - self.img_h, - self.img_w, - self.patch_dim, - self.special_token_ids, - ) - mimo_model = mimo_model.to(device) - - # Inject a mock partition adapter that halves the sequence dimension sharded_seq_len = self.seq_len // 2 - sharded_emb = torch.zeros(self.batch_size, sharded_seq_len, self.hidden_size, device=device) + sharded_emb = torch.zeros( + self.batch_size, sharded_seq_len, self.hidden_size, device=self.device + ) mock_adapter = MagicMock() mock_adapter.shard.return_value = (sharded_emb, None, None, None, None) mimo_model.partition_adapter = mock_adapter - text_emb = torch.zeros(self.batch_size * self.seq_len, self.hidden_size, device=device) - # align_embeddings_by_token_positions returns [S, B, H] - combined_emb = torch.zeros(self.seq_len, self.batch_size, self.hidden_size, device=device) + text_emb = torch.zeros(self.batch_size * self.seq_len, self.hidden_size, device=self.device) + combined_emb = torch.zeros( + self.seq_len, self.batch_size, self.hidden_size, device=self.device + ) captured = {} def capture_lm_forward(*args, **kwargs): captured['decoder_input'] = kwargs.get('decoder_input') - return torch.zeros(self.batch_size, sharded_seq_len, self.vocab_size, device=device) + return torch.zeros( + self.batch_size, sharded_seq_len, self.vocab_size, device=self.device + ) with ( patch.object(mimo_model, 'get_text_embeddings', return_value=text_emb), @@ -614,14 +420,9 @@ def capture_lm_forward(*args, **kwargs): ): mimo_model(input_ids=input_ids, position_ids=position_ids, modality_inputs=None) - # shard() should have been called once mock_adapter.shard.assert_called_once() - - # The embeddings passed to shard() must be [B, S, H] (transposed from [S, B, H]) shard_kwargs = mock_adapter.shard.call_args[1] assert shard_kwargs['embeddings'].shape == (self.batch_size, self.seq_len, self.hidden_size) - - # The language model decoder_input must be [S/cp, B, H] (re-transposed after shard) assert captured['decoder_input'].shape == ( sharded_seq_len, self.batch_size, @@ -671,6 +472,9 @@ def setup_method(self, method): self.vocab_size = 48000 self.seq_len = 256 self.batch_size = 2 + self.img_h = 224 + self.img_w = 224 + self.patch_dim = 16 def teardown_method(self, method): try: @@ -683,9 +487,11 @@ def _make_config(self, encoder_in_grid=True, language_in_grid=True, pp_rank=0, p language_model_spec = get_language_model_spec( self.hidden_size, self.vocab_size, self.seq_len ) - vision_submodule_spec = get_vision_submodules_spec(self.hidden_size, 224, 224, 16) + vision_submodule_spec = get_vision_submodules_spec( + self.hidden_size, self.img_h, self.img_w, self.patch_dim + ) - encoder_offset = 0 if encoder_in_grid else 10 # rank 0 in grid if offset=0 + encoder_offset = 0 if encoder_in_grid else 10 language_offset = 0 if language_in_grid else 10 return MimoModelConfig( @@ -716,9 +522,10 @@ def test_grid_validation_rejects_mismatched_keys(self): language_model_spec = get_language_model_spec( self.hidden_size, self.vocab_size, self.seq_len ) - vision_submodule_spec = get_vision_submodules_spec(self.hidden_size, 224, 224, 16) + vision_submodule_spec = get_vision_submodules_spec( + self.hidden_size, self.img_h, self.img_w, self.patch_dim + ) - # Missing 'images' in grid_map mimo_config = MimoModelConfig( language_model_spec=language_model_spec, modality_submodules_spec={"images": vision_submodule_spec}, @@ -734,16 +541,22 @@ def test_role_determination(self): """Test role correctly identifies modules and stage positions.""" # No grid map = no role model_no_grid = get_vlm_mimo_model( - self.hidden_size, self.vocab_size, self.seq_len, 224, 224, 16, {"images": 50257} + self.hidden_size, + self.vocab_size, + self.seq_len, + self.img_h, + self.img_w, + self.patch_dim, + {"images": 50257}, ) assert model_no_grid.role is None - # Encoder-only rank (language grid excludes rank 0) + # Encoder-only rank model_encoder = MimoModel(self._make_config(encoder_in_grid=True, language_in_grid=False)) assert model_encoder.role.has_modality_modules is True assert model_encoder.role.has_language_module is False - # Language-only rank (encoder grid excludes rank 0) + # Language-only rank model_language = MimoModel(self._make_config(encoder_in_grid=False, language_in_grid=True)) assert model_language.role.has_modality_modules is False assert model_language.role.has_language_module is True @@ -772,7 +585,7 @@ def test_forward_encoder_only(self): model = MimoModel(self._make_config(encoder_in_grid=True, language_in_grid=False)) model = model.to(self.device) - images = torch.rand(2, 3, 224, 224, device=self.device) + images = torch.rand(2, 3, self.img_h, self.img_w, device=self.device) input_ids = torch.randint( 0, self.vocab_size, (self.batch_size, self.seq_len), device=self.device ) @@ -780,7 +593,6 @@ def test_forward_encoder_only(self): outputs, _ = model( input_ids=input_ids, modality_inputs={"images": {"clip_encoder": {"x": images}}} ) - assert isinstance(outputs, dict) assert "images" in outputs @@ -789,22 +601,20 @@ def test_forward_language_only(self): model = MimoModel(self._make_config(encoder_in_grid=False, language_in_grid=True)) model = model.to(self.device) - img_seq_len = (224 // 16) * (224 // 16) + 1 + img_seq_len = (self.img_h // self.patch_dim) * (self.img_w // self.patch_dim) + 1 input_ids = torch.randint( 0, self.vocab_size, (self.batch_size, self.seq_len), device=self.device ) - input_ids[:, 5 : 5 + img_seq_len] = 50257 # image tokens + input_ids[:, 5 : 5 + img_seq_len] = 50257 position_ids = ( torch.arange(self.seq_len, device=self.device).unsqueeze(0).expand(self.batch_size, -1) ) - # Simulate encoder output from previous stage encoder_embeddings = torch.randn( self.batch_size * img_seq_len, self.hidden_size, device=self.device ) model.set_input_tensor({"images": encoder_embeddings}) outputs, _ = model(input_ids=input_ids, position_ids=position_ids, modality_inputs=None) - assert isinstance(outputs, torch.Tensor) assert outputs.shape == (self.batch_size, self.seq_len, self.vocab_size) From bf75c2d8d66a6bf51ba96939ff077eb7cd0fb231 Mon Sep 17 00:00:00 2001 From: ykarnati Date: Thu, 19 Mar 2026 10:47:43 -0700 Subject: [PATCH 18/30] Make RankRole always-present and simplify encoder dispatch - Add colocated flag and RankRole.all_modules() factory so _determine_role always returns a RankRole (never None) - Remove all `if self.role is not None` guards from _initialize_submodules, _initialize_language_model, and forward() - forward() checks self.role.colocated instead of self.role is None - Rank-not-in-any-grid now raises RuntimeError immediately in _determine_role instead of returning None and failing later - Simplify _forward_encoders: pass both encoder_inputs and hidden_states to submodule, let its is_first_stage flag decide which to use - Update test_role_determination to assert colocated role properties Co-Authored-By: Claude Opus 4.6 (1M context) --- megatron/core/models/mimo/config/role.py | 16 ++++++ megatron/core/models/mimo/model/base.py | 57 +++++++++------------- tests/unit_tests/models/test_mimo_model.py | 6 ++- 3 files changed, 43 insertions(+), 36 deletions(-) diff --git a/megatron/core/models/mimo/config/role.py b/megatron/core/models/mimo/config/role.py index 677218f4c24..13a6552e85d 100644 --- a/megatron/core/models/mimo/config/role.py +++ b/megatron/core/models/mimo/config/role.py @@ -32,10 +32,14 @@ class RankRole: this rank participates in. language_module_name: Name of the language module, used to distinguish encoders from the language model. + colocated: If True, all modules run on all ranks (no multi-module PP). + The forward path uses the colocated codepath which supports + PartitionAdapter and PackedSeqParams. """ modules: Dict[str, ModuleStageInfo] = field(default_factory=dict) language_module_name: Optional[str] = None + colocated: bool = False def __post_init__(self): """Validate that language_module_name is set when modules is non-empty.""" @@ -45,6 +49,18 @@ def __post_init__(self): f"Got modules={list(self.modules.keys())} with language_module_name=None." ) + @classmethod + def all_modules(cls, module_names: List[str], language_module_name: str) -> 'RankRole': + """Create a role for the colocated case: every module, first+last stage.""" + return cls( + modules={ + name: ModuleStageInfo(is_first_stage=True, is_last_stage=True) + for name in module_names + }, + language_module_name=language_module_name, + colocated=True, + ) + @property def has_modality_modules(self) -> bool: """Return True if this rank participates in any modality (non-language) module.""" diff --git a/megatron/core/models/mimo/model/base.py b/megatron/core/models/mimo/model/base.py index 8df257b7c5e..160f34965a6 100644 --- a/megatron/core/models/mimo/model/base.py +++ b/megatron/core/models/mimo/model/base.py @@ -166,18 +166,13 @@ def _initialize_submodules(self) -> None: initialization on non-last stages (saves memory in pipeline parallelism). """ for modality_name, submodule_spec in self.mimo_config.modality_submodules_spec.items(): - # Skip if we have a role and this module isn't in it - if self.role is not None and modality_name not in self.role.modules: + if modality_name not in self.role.modules: logger.debug(f"Skipping {modality_name} submodule (not in role)") continue - # Determine stage info for this module - is_first_stage = True - is_last_stage = True - if self.role is not None: - stage_info = self.role.modules[modality_name] - is_first_stage = stage_info.is_first_stage - is_last_stage = stage_info.is_last_stage + stage_info = self.role.modules[modality_name] + is_first_stage = stage_info.is_first_stage + is_last_stage = stage_info.is_last_stage submodule_class = submodule_spec.module logger.debug( @@ -197,8 +192,7 @@ def _initialize_language_model(self) -> None: When role is set, only initializes if this rank participates in language module. """ - # Skip if we have a role and don't participate in language module - if self.role is not None and not self.role.has_language_module: + if not self.role.has_language_module: logger.debug("Skipping language model initialization (not in role)") self.language_model = None return @@ -241,15 +235,20 @@ def _validate_grid_map(self) -> None: f"Extra in grid_map: {extra_in_grid}" ) - def _determine_role(self) -> Optional[RankRole]: + def _determine_role(self) -> RankRole: """Determine this rank's role based on grid map. Returns: - RankRole describing which modules this rank participates in, - or None if module_to_grid_map is not set (all modules on all ranks). + RankRole describing which modules this rank participates in. + For the colocated case (no module_to_grid_map), returns a role with + all modules at first+last stage and colocated=True. """ if not self.mimo_config.module_to_grid_map: - return None + # Colocated: all modules on all ranks, single stage + all_module_names = list(self.mimo_config.modality_submodules_spec.keys()) + language_key = self.mimo_config.language_module_key or "_language" + all_module_names.append(language_key) + return RankRole.all_modules(all_module_names, language_key) current_rank = dist.get_rank() modules = {} @@ -278,7 +277,10 @@ def _determine_role(self) -> Optional[RankRole]: modules[module_name] = ModuleStageInfo(is_first_stage=is_first, is_last_stage=is_last) if not modules: - return None + raise RuntimeError( + f"Rank {current_rank} is not in any module grid. " + f"Check module_to_grid_map configuration." + ) return RankRole(modules=modules, language_module_name=self.mimo_config.language_module_key) @@ -398,8 +400,7 @@ def forward( # Get any tensors passed via set_input_tensor input_tensors = getattr(self, 'input_tensors', None) - if self.role is None: - # Original behavior: all modules on all ranks + if self.role.colocated: return self._forward_all_modules( input_ids, position_ids, @@ -453,22 +454,10 @@ def _forward_encoders( continue submodule = self.modality_submodules[encoder_name] - - # Determine input based on stage position - if self.role.is_first_stage(encoder_name): - encoder_input = modality_inputs.get(encoder_name) if modality_inputs else None - output = ( - submodule.forward(encoder_inputs=encoder_input) - if encoder_input is not None - else None - ) - else: - hidden_states = input_tensors.get(encoder_name) if input_tensors else None - output = ( - submodule.forward(hidden_states=hidden_states) - if hidden_states is not None - else None - ) + output = submodule.forward( + encoder_inputs=modality_inputs.get(encoder_name) if modality_inputs else None, + hidden_states=input_tensors.get(encoder_name) if input_tensors else None, + ) if output is not None: outputs[encoder_name] = output diff --git a/tests/unit_tests/models/test_mimo_model.py b/tests/unit_tests/models/test_mimo_model.py index e29dad10edf..ae75e755a9c 100644 --- a/tests/unit_tests/models/test_mimo_model.py +++ b/tests/unit_tests/models/test_mimo_model.py @@ -539,7 +539,7 @@ def test_grid_validation_rejects_mismatched_keys(self): def test_role_determination(self): """Test role correctly identifies modules and stage positions.""" - # No grid map = no role + # No grid map = colocated role with all modules model_no_grid = get_vlm_mimo_model( self.hidden_size, self.vocab_size, @@ -549,7 +549,9 @@ def test_role_determination(self): self.patch_dim, {"images": 50257}, ) - assert model_no_grid.role is None + assert model_no_grid.role.colocated is True + assert model_no_grid.role.has_language_module is True + assert model_no_grid.role.has_modality_modules is True # Encoder-only rank model_encoder = MimoModel(self._make_config(encoder_in_grid=True, language_in_grid=False)) From 4bb971190b8a2a55cfa35a962087c453e8e1294b Mon Sep 17 00:00:00 2001 From: ykarnati Date: Thu, 19 Mar 2026 10:58:03 -0700 Subject: [PATCH 19/30] Replace configurable language_module_key with fixed LANGUAGE_MODULE_KEY MIMO always has exactly one language model, so the key doesn't need to be configurable. This removes: - language_module_key field from MimoModelConfig - language_module_name field from RankRole - Validation that language_module_key is set - The or "_language" fallback hack in _determine_role Replaced with a single constant LANGUAGE_MODULE_KEY = "language" in config/role.py, used consistently across base.py and tests. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../core/models/mimo/config/base_configs.py | 5 +-- megatron/core/models/mimo/config/role.py | 28 ++++++--------- megatron/core/models/mimo/model/base.py | 35 +++++++------------ tests/unit_tests/models/test_mimo_model.py | 7 ++-- 4 files changed, 27 insertions(+), 48 deletions(-) diff --git a/megatron/core/models/mimo/config/base_configs.py b/megatron/core/models/mimo/config/base_configs.py index 31361ecbdaa..7fe405b6454 100644 --- a/megatron/core/models/mimo/config/base_configs.py +++ b/megatron/core/models/mimo/config/base_configs.py @@ -23,10 +23,8 @@ class MimoModelConfig: module_to_grid_map (Optional[Dict[str, Any]]): Dictionary mapping module keys (e.g., "vision", "language") to their corresponding grid configurations for non-colocated pipeline parallelism. + The language model must use the key "language" (LANGUAGE_MODULE_KEY). When None, all modules are assumed to be colocated on the same ranks. - language_module_key (Optional[str]): - The key used to identify the language module in the module_to_grid_map. - Required when module_to_grid_map is provided. kv_format (str): Key-value format for attention: "sbhd" (seq-batch-head-dim) or "thd" (total-head-dim). Default is "sbhd". @@ -43,5 +41,4 @@ class MimoModelConfig: modality_submodules_spec: Dict[str, ModuleSpec] = field(default_factory=dict) special_token_ids: Dict[str, int] = field(default_factory=dict) module_to_grid_map: Optional[Dict[str, Any]] = None - language_module_key: Optional[str] = None kv_format: str = "sbhd" diff --git a/megatron/core/models/mimo/config/role.py b/megatron/core/models/mimo/config/role.py index 13a6552e85d..408f22d43b5 100644 --- a/megatron/core/models/mimo/config/role.py +++ b/megatron/core/models/mimo/config/role.py @@ -3,7 +3,11 @@ """Data classes for MIMO rank role management in multi-module pipeline parallelism.""" from dataclasses import dataclass, field -from typing import Dict, List, Optional +from typing import Dict, List + +# Fixed key for the language module in module_to_grid_map and RankRole. +# MIMO always has exactly one language model, so this is not configurable. +LANGUAGE_MODULE_KEY = "language" @dataclass @@ -25,56 +29,44 @@ class RankRole: This class captures the role of a specific rank in a multi-module pipeline parallel setup, tracking which modules the rank participates in and their - stage positions. + stage positions. The language module is always identified by LANGUAGE_MODULE_KEY. Args: modules: Dict mapping module names to their stage info for modules this rank participates in. - language_module_name: Name of the language module, used to distinguish - encoders from the language model. colocated: If True, all modules run on all ranks (no multi-module PP). The forward path uses the colocated codepath which supports PartitionAdapter and PackedSeqParams. """ modules: Dict[str, ModuleStageInfo] = field(default_factory=dict) - language_module_name: Optional[str] = None colocated: bool = False - def __post_init__(self): - """Validate that language_module_name is set when modules is non-empty.""" - if self.modules and self.language_module_name is None: - raise ValueError( - "language_module_name must be set when modules is non-empty. " - f"Got modules={list(self.modules.keys())} with language_module_name=None." - ) - @classmethod - def all_modules(cls, module_names: List[str], language_module_name: str) -> 'RankRole': + def all_modules(cls, module_names: List[str]) -> 'RankRole': """Create a role for the colocated case: every module, first+last stage.""" return cls( modules={ name: ModuleStageInfo(is_first_stage=True, is_last_stage=True) for name in module_names }, - language_module_name=language_module_name, colocated=True, ) @property def has_modality_modules(self) -> bool: """Return True if this rank participates in any modality (non-language) module.""" - return any(name != self.language_module_name for name in self.modules) + return any(name != LANGUAGE_MODULE_KEY for name in self.modules) @property def has_language_module(self) -> bool: """Return True if this rank participates in the language module.""" - return self.language_module_name is not None and self.language_module_name in self.modules + return LANGUAGE_MODULE_KEY in self.modules @property def modality_module_names(self) -> List[str]: """Return names of modality modules (non-language) this rank participates in.""" - return [name for name in self.modules if name != self.language_module_name] + return [name for name in self.modules if name != LANGUAGE_MODULE_KEY] def is_first_stage(self, module_name: str) -> bool: """Check if this rank is the first stage for a given module.""" diff --git a/megatron/core/models/mimo/model/base.py b/megatron/core/models/mimo/model/base.py index 160f34965a6..115109f4874 100644 --- a/megatron/core/models/mimo/model/base.py +++ b/megatron/core/models/mimo/model/base.py @@ -8,7 +8,7 @@ import torch.distributed as dist from megatron.core.models.mimo.config import MimoModelConfig -from megatron.core.models.mimo.config.role import ModuleStageInfo, RankRole +from megatron.core.models.mimo.config.role import LANGUAGE_MODULE_KEY, ModuleStageInfo, RankRole from megatron.core.models.mimo.partition.utils import PartitionAdapter, PartitionConfig from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.transformer import MegatronModule @@ -205,9 +205,8 @@ def _initialize_language_model(self) -> None: def _validate_grid_map(self) -> None: """Validate module_to_grid_map consistency with submodule config. - Validates that: - - language_module_key is set when module_to_grid_map is provided - - module_to_grid_map keys exactly match modality_submodules_spec keys + language_module_key + Validates that module_to_grid_map keys exactly match + modality_submodules_spec keys + LANGUAGE_MODULE_KEY. Raises: ValueError: If validation fails. @@ -215,23 +214,16 @@ def _validate_grid_map(self) -> None: if not self.mimo_config.module_to_grid_map: return - # Require language_module_key when using multi-module PP - if self.mimo_config.language_module_key is None: - raise ValueError( - "language_module_key must be set when module_to_grid_map is provided. " - "Specify which module key identifies the language model." - ) - grid_map_keys = set(self.mimo_config.module_to_grid_map.keys()) - submodule_keys = set(self.mimo_config.modality_submodules_spec.keys()) - submodule_keys.add(self.mimo_config.language_module_key) + expected_keys = set(self.mimo_config.modality_submodules_spec.keys()) + expected_keys.add(LANGUAGE_MODULE_KEY) - if grid_map_keys != submodule_keys: - missing_in_grid = submodule_keys - grid_map_keys - extra_in_grid = grid_map_keys - submodule_keys + if grid_map_keys != expected_keys: + missing_in_grid = expected_keys - grid_map_keys + extra_in_grid = grid_map_keys - expected_keys raise ValueError( f"module_to_grid_map keys must match modality_submodules_spec keys + " - f"language_module_key. Missing in grid_map: {missing_in_grid}, " + f"'{LANGUAGE_MODULE_KEY}'. Missing in grid_map: {missing_in_grid}, " f"Extra in grid_map: {extra_in_grid}" ) @@ -246,9 +238,8 @@ def _determine_role(self) -> RankRole: if not self.mimo_config.module_to_grid_map: # Colocated: all modules on all ranks, single stage all_module_names = list(self.mimo_config.modality_submodules_spec.keys()) - language_key = self.mimo_config.language_module_key or "_language" - all_module_names.append(language_key) - return RankRole.all_modules(all_module_names, language_key) + all_module_names.append(LANGUAGE_MODULE_KEY) + return RankRole.all_modules(all_module_names) current_rank = dist.get_rank() modules = {} @@ -282,7 +273,7 @@ def _determine_role(self) -> RankRole: f"Check module_to_grid_map configuration." ) - return RankRole(modules=modules, language_module_name=self.mimo_config.language_module_key) + return RankRole(modules=modules) def set_input_tensor(self, input_tensor): """Set input tensor for pipeline parallelism. @@ -484,7 +475,7 @@ def _forward_language_module( Returns: Language model output (hidden states, logits, or loss depending on stage) """ - lang_name = self.role.language_module_name + lang_name = LANGUAGE_MODULE_KEY if self.role.is_first_stage(lang_name): # First stage: receive encoder embeddings, combine with text, pass to LM diff --git a/tests/unit_tests/models/test_mimo_model.py b/tests/unit_tests/models/test_mimo_model.py index ae75e755a9c..2e2e443e473 100644 --- a/tests/unit_tests/models/test_mimo_model.py +++ b/tests/unit_tests/models/test_mimo_model.py @@ -15,6 +15,7 @@ from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec from megatron.core.models.gpt.gpt_model import GPTModel from megatron.core.models.mimo.config.base_configs import MimoModelConfig +from megatron.core.models.mimo.config.role import LANGUAGE_MODULE_KEY from megatron.core.models.mimo.model.base import MimoModel from megatron.core.models.mimo.submodules.audio import AudioModalitySubmodules from megatron.core.models.mimo.submodules.vision import VisionModalitySubmodules @@ -506,7 +507,7 @@ def _make_config(self, encoder_in_grid=True, language_in_grid=True, pp_rank=0, p pp_rank=pp_rank, pp_size=pp_size, ), - "language": MockGrid( + LANGUAGE_MODULE_KEY: MockGrid( rank_offset=language_offset, size=1, dim_names=["pp"] if pp_size > 1 else [], @@ -514,7 +515,6 @@ def _make_config(self, encoder_in_grid=True, language_in_grid=True, pp_rank=0, p pp_size=pp_size, ), }, - language_module_key="language", ) def test_grid_validation_rejects_mismatched_keys(self): @@ -530,8 +530,7 @@ def test_grid_validation_rejects_mismatched_keys(self): language_model_spec=language_model_spec, modality_submodules_spec={"images": vision_submodule_spec}, special_token_ids={"images": 50257}, - module_to_grid_map={"language": MockGrid()}, - language_module_key="language", + module_to_grid_map={LANGUAGE_MODULE_KEY: MockGrid()}, ) with pytest.raises(ValueError, match="module_to_grid_map keys must match"): From a01478260216e87479e9257e1dd3df1a5c867e83 Mon Sep 17 00:00:00 2001 From: ykarnati Date: Mon, 2 Feb 2026 16:45:06 -0800 Subject: [PATCH 20/30] Add MimoOptimizer for heterogeneous parallelism Adds optimizer support for MIMO models where different modules (encoder, LLM) may have different DP/TP/PP configurations. Key features: - MimoOptimizer class managing per-module optimizers - True global gradient norm via all_reduce MAX across module boundaries - Module-aware checkpointing (state_dict keyed by module name) - Simple API: get_mimo_optimizer(mimo_model, config) --- megatron/core/models/mimo/__init__.py | 3 + megatron/core/models/mimo/optimizer.py | 220 ++++++++++++++ .../models/test_mimo_1f1b_schedule.py | 24 ++ .../unit_tests/models/test_mimo_optimizer.py | 281 ++++++++++++++++++ 4 files changed, 528 insertions(+) create mode 100644 megatron/core/models/mimo/optimizer.py create mode 100644 tests/unit_tests/models/test_mimo_optimizer.py diff --git a/megatron/core/models/mimo/__init__.py b/megatron/core/models/mimo/__init__.py index 204851c444b..779bf921e1c 100644 --- a/megatron/core/models/mimo/__init__.py +++ b/megatron/core/models/mimo/__init__.py @@ -2,6 +2,7 @@ from megatron.core.models.mimo.config.base_configs import MimoModelConfig from megatron.core.models.mimo.model import MimoModel +from megatron.core.models.mimo.optimizer import MimoOptimizer, get_mimo_optimizer from megatron.core.models.mimo.submodules.audio import AudioModalitySubmodules from megatron.core.models.mimo.submodules.base import ModalitySubmodules from megatron.core.models.mimo.submodules.vision import VisionModalitySubmodules @@ -9,6 +10,8 @@ __all__ = [ 'MimoModelConfig', 'MimoModel', + 'MimoOptimizer', + 'get_mimo_optimizer', # Submodule classes 'ModalitySubmodules', 'VisionModalitySubmodules', diff --git a/megatron/core/models/mimo/optimizer.py b/megatron/core/models/mimo/optimizer.py new file mode 100644 index 00000000000..c967df0c8e6 --- /dev/null +++ b/megatron/core/models/mimo/optimizer.py @@ -0,0 +1,220 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +"""Optimizer for MIMO models with heterogeneous parallelism.""" + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple + +import torch + +from megatron.core.optimizer.clip_grads import clip_grad_by_total_norm_fp32 +from megatron.core.optimizer.optimizer import MegatronOptimizer +from megatron.core.optimizer.optimizer_config import OptimizerConfig +from megatron.core.process_groups_config import ProcessGroupCollection + + +@dataclass +class ModuleOptimizerInfo: + """Optimizer info for a single module.""" + + optimizer: Optional[MegatronOptimizer] + grid: Any # HyperCommGrid + pg_collection: Optional[ProcessGroupCollection] + is_active: bool + + +class MimoOptimizer(MegatronOptimizer): + """ + Optimizer for MimoModel with heterogeneous parallelism. + + Each module gets its own optimizer. Global gradient norm is computed + across all modules via all_reduce MAX. + """ + + def __init__( + self, + module_infos: Dict[str, ModuleOptimizerInfo], + config: OptimizerConfig, + ): + self.module_infos = module_infos + self.config = config + self._active_optimizers: List[MegatronOptimizer] = [ + info.optimizer + for info in module_infos.values() + if info.is_active and info.optimizer is not None + ] + self.is_stub_optimizer = len(self._active_optimizers) == 0 + self.optimizer = None # Base class compat + + @torch.no_grad() + def prepare_grads(self) -> bool: + found_inf = False + for opt in self._active_optimizers: + found_inf |= opt.prepare_grads() + return found_inf + + @torch.no_grad() + def get_grad_norm(self) -> float: + """Compute global gradient norm across all modules via all_reduce MAX.""" + num_modules = len(self.module_infos) + norm_sq = torch.zeros(num_modules, device="cuda", dtype=torch.float32) + + for i, (name, info) in enumerate(sorted(self.module_infos.items())): + if info.is_active and info.optimizer: + module_norm = info.optimizer.get_grad_norm() or 0.0 + norm_sq[i] = module_norm**2 + + torch.distributed.all_reduce(norm_sq, op=torch.distributed.ReduceOp.MAX) + return torch.sqrt(norm_sq.sum()).item() + + @torch.no_grad() + def step(self) -> Tuple[bool, Optional[float], Optional[int]]: + found_inf = self.prepare_grads() + if found_inf: + return False, None, None + + grad_norm = self.get_grad_norm() + + # Clip with global norm + for opt in self._active_optimizers: + if getattr(opt, "is_stub_optimizer", False): + continue + params = opt.get_parameters() + if params and opt.config.clip_grad > 0.0: + clip_grad_by_total_norm_fp32( + params, + max_norm=opt.config.clip_grad, + total_norm=grad_norm, + use_decoupled_grad=opt.config.use_precision_aware_optimizer, + ) + + num_zeros = self.count_zeros() if self.config.log_num_zeros_in_grad else None + success = self.step_with_ready_grads() + + return success, grad_norm, num_zeros + + @torch.no_grad() + def step_with_ready_grads(self) -> bool: + success = True + for opt in self._active_optimizers: + success &= opt.step_with_ready_grads() + return success + + def zero_grad(self, set_to_none: bool = True): + for opt in self._active_optimizers: + opt.zero_grad(set_to_none) + + def get_loss_scale(self) -> torch.Tensor: + if self._active_optimizers: + return self._active_optimizers[0].get_loss_scale() + return torch.tensor([1.0], dtype=torch.float32, device="cuda") + + def count_zeros(self) -> int: + return sum(opt.count_zeros() for opt in self._active_optimizers) + + @property + def param_groups(self) -> List[dict]: + groups = [] + for opt in self._active_optimizers: + groups.extend(opt.param_groups) + return groups + + # Checkpointing + + def state_dict(self): + return { + name: info.optimizer.state_dict() if info.is_active and info.optimizer else None + for name, info in self.module_infos.items() + } + + def load_state_dict(self, state_dict: Dict): + for name, info in self.module_infos.items(): + if info.is_active and info.optimizer and state_dict.get(name): + info.optimizer.load_state_dict(state_dict[name]) + + def sharded_state_dict(self, model_sharded_state_dict, is_loading: bool = False, **kwargs): + sharded_state = {} + for name, info in self.module_infos.items(): + if info.is_active and info.optimizer: + sharded_state[name] = info.optimizer.sharded_state_dict( + model_sharded_state_dict, is_loading, **kwargs + ) + return sharded_state + + def reload_model_params(self, state_dict=None): + for opt in self._active_optimizers: + opt.reload_model_params(state_dict) + + +def _get_pg_collection_from_grid(grid) -> ProcessGroupCollection: + """Create ProcessGroupCollection from HyperCommGrid.""" + import torch.distributed as dist + + from megatron.core.pipeline_parallel.utils import is_pp_first_stage, is_pp_last_stage + + pg = ProcessGroupCollection() + pg.tp = grid.get_pg("tp") + pg.cp = grid.get_pg("cp") + pg.pp = grid.get_pg("pp") + pg.ep = grid.get_pg("ep") + pg.dp = grid.get_pg("dp") + pg.dp_cp = grid.get_pg(["dp", "cp"]) + + # Embedding groups + if pg.pp: + pp_ranks = sorted(dist.get_process_group_ranks(pg.pp)) + pos_embd_ranks = [pp_ranks[0]] + embd_ranks = [pp_ranks[0]] + if pp_ranks[-1] != pp_ranks[0]: + embd_ranks.append(pp_ranks[-1]) + + pos_embd_pg = dist.new_group(ranks=pos_embd_ranks) + embd_pg = dist.new_group(ranks=embd_ranks) + + pg.pos_embd = pos_embd_pg if is_pp_first_stage(pg.pp) else None + pg.embd = embd_pg if (is_pp_last_stage(pg.pp) or is_pp_first_stage(pg.pp)) else None + + pg.mp = grid.get_pg("tp") + return pg + + +def get_mimo_optimizer( + mimo_model: "MimoModel", + config: OptimizerConfig, +) -> MimoOptimizer: + """Create optimizer for MimoModel with heterogeneous parallelism.""" + from megatron.core.optimizer import get_megatron_optimizer + + grid_map = mimo_model.mimo_config.module_to_grid_map + lang_key = mimo_model.mimo_config.language_module_key + + module_infos: Dict[str, ModuleOptimizerInfo] = {} + + for module_name, grid in grid_map.items(): + is_active = grid.is_rank_in_grid() + + optimizer = None + pg_collection = None + + if is_active: + if module_name == lang_key: + module = mimo_model.language_model + else: + module = mimo_model.modality_submodules.get(module_name) + + if module is not None: + pg_collection = _get_pg_collection_from_grid(grid) + optimizer = get_megatron_optimizer( + config=config, + model_chunks=[module], + pg_collection=pg_collection, + ) + + module_infos[module_name] = ModuleOptimizerInfo( + optimizer=optimizer, + grid=grid, + pg_collection=pg_collection, + is_active=is_active, + ) + + return MimoOptimizer(module_infos, config) diff --git a/tests/unit_tests/models/test_mimo_1f1b_schedule.py b/tests/unit_tests/models/test_mimo_1f1b_schedule.py index b6cd67b6f54..6f366e9ec12 100644 --- a/tests/unit_tests/models/test_mimo_1f1b_schedule.py +++ b/tests/unit_tests/models/test_mimo_1f1b_schedule.py @@ -20,7 +20,9 @@ from megatron.core.models.gpt.gpt_model import GPTModel from megatron.core.models.mimo.config.base_configs import MimoModelConfig from megatron.core.models.mimo.model.base import MimoModel +from megatron.core.models.mimo.optimizer import get_mimo_optimizer from megatron.core.models.mimo.submodules.vision import VisionModalitySubmodules +from megatron.core.optimizer.optimizer_config import OptimizerConfig from megatron.core.models.vision.multimodal_projector import MultimodalProjector from megatron.core.pipeline_parallel.multimodule_communicator import ( MultiModulePipelineCommunicator, @@ -515,6 +517,19 @@ def finalize_grads_func(*args, **kwargs): if isinstance(loss, (int, float)) else loss ) + # Create MimoOptimizer + logger.info(f"[Rank {dist.get_rank()}] Creating MimoOptimizer...") + opt_config = OptimizerConfig( + optimizer='adam', + lr=1e-4, + weight_decay=0.01, + clip_grad=1.0, + bf16=True, + use_distributed_optimizer=True, + ) + optimizer = get_mimo_optimizer(mimo_model, opt_config) + logger.info(f"[Rank {dist.get_rank()}] MimoOptimizer created with {len(optimizer._active_optimizers)} active optimizers") + logger.info(f"[Rank {dist.get_rank()}] Creating communicator...") communicator = MultiModulePipelineCommunicator( module_to_grid_map, topology, mimo_model.config, dim_mapping={'s': 0, 'h': 2, 'b': 1} @@ -585,6 +600,10 @@ def loss_func(loss_mask, output_tensor): return output_tensor, partial(loss_func, loss_mask) logger.info(f"[Rank {dist.get_rank()}] Running 1F1B schedule with {num_microbatches} microbatches...") + + # Zero gradients before forward/backward + optimizer.zero_grad() + losses = schedule.forward_backward_pipelining_without_interleaving( forward_step_func=step_func, data_iterator=data_iterator, @@ -597,6 +616,11 @@ def loss_func(loss_mask, output_tensor): pg_collection=pg_collection, ) + # Optimizer step with global gradient clipping + logger.info(f"[Rank {dist.get_rank()}] Running optimizer step...") + success, grad_norm, num_zeros = optimizer.step() + logger.info(f"[Rank {dist.get_rank()}] Optimizer step: success={success}, grad_norm={grad_norm}") + # Verify results on last LLM stage if is_rank_in_grid(llm_grid): if is_pp_last_stage(llm_grid.get_pg("pp")): diff --git a/tests/unit_tests/models/test_mimo_optimizer.py b/tests/unit_tests/models/test_mimo_optimizer.py new file mode 100644 index 00000000000..2634bacd56a --- /dev/null +++ b/tests/unit_tests/models/test_mimo_optimizer.py @@ -0,0 +1,281 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +"""Tests for MimoOptimizer. + +Unit tests (no distributed): + pytest tests/unit_tests/models/test_mimo_optimizer.py -v -k "not distributed" + +Integration tests (requires torchrun): + torchrun --nproc_per_node=2 tests/unit_tests/models/test_mimo_optimizer.py +""" + +import pytest +import torch + +from megatron.core.models.mimo.optimizer import ModuleOptimizerInfo, MimoOptimizer + + +class TestModuleOptimizerInfo: + """Tests for ModuleOptimizerInfo dataclass.""" + + def test_create_active(self): + info = ModuleOptimizerInfo( + optimizer=None, + grid=None, + pg_collection=None, + is_active=True, + ) + assert info.is_active is True + + def test_create_inactive(self): + info = ModuleOptimizerInfo( + optimizer=None, + grid=None, + pg_collection=None, + is_active=False, + ) + assert info.is_active is False + + +class TestMimoOptimizerUnit: + """Unit tests for MimoOptimizer (no distributed required).""" + + def test_init_empty(self): + """Test initialization with no active optimizers.""" + from megatron.core.optimizer.optimizer_config import OptimizerConfig + + config = OptimizerConfig(optimizer='adam', lr=1e-4) + module_infos = { + "encoder": ModuleOptimizerInfo(None, None, None, is_active=False), + "language": ModuleOptimizerInfo(None, None, None, is_active=False), + } + opt = MimoOptimizer(module_infos, config) + + assert opt.is_stub_optimizer is True + assert len(opt._active_optimizers) == 0 + + def test_param_groups_empty(self): + """Test param_groups property with no active optimizers.""" + from megatron.core.optimizer.optimizer_config import OptimizerConfig + + config = OptimizerConfig(optimizer='adam', lr=1e-4) + module_infos = {} + opt = MimoOptimizer(module_infos, config) + + assert opt.param_groups == [] + + def test_state_dict_empty(self): + """Test state_dict with no active optimizers.""" + from megatron.core.optimizer.optimizer_config import OptimizerConfig + + config = OptimizerConfig(optimizer='adam', lr=1e-4) + module_infos = { + "encoder": ModuleOptimizerInfo(None, None, None, is_active=False), + } + opt = MimoOptimizer(module_infos, config) + + state = opt.state_dict() + assert "encoder" in state + assert state["encoder"] is None + + +# ============================================================================ +# Integration tests (require torchrun) +# ============================================================================ + +def run_distributed_test(): + """Run distributed integration test.""" + import torch.distributed as dist + + from megatron.core.distributed import DistributedDataParallel, DistributedDataParallelConfig + from megatron.core.hyper_comm_grid import HyperCommGrid + from megatron.core.models.mimo import MimoModel, MimoModelConfig, get_mimo_optimizer + from megatron.core.models.mimo.submodules.vision import VisionModalitySubmodules + from megatron.core.optimizer.optimizer_config import OptimizerConfig + from megatron.core.transformer.spec_utils import ModuleSpec + from megatron.core.transformer.transformer_config import TransformerConfig + from megatron.core.models.gpt.gpt_model import GPTModel + from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec + from megatron.core.transformer.transformer_block import TransformerBlock + from megatron.core.process_groups_config import ProcessGroupCollection + from megatron.core.pipeline_parallel.utils import is_pp_first_stage, is_pp_last_stage + + def create_grid(offset=0, tp=1, pp=1, dp=1): + grid = HyperCommGrid( + shape=[tp, 1, pp, dp, 1], + dim_names=["tp", "cp", "pp", "dp", "ep"], + rank_offset=offset, + backend="nccl", + ) + grid.create_pg(["tp"]) + grid.create_pg(["cp"]) + grid.create_pg(["pp"]) + grid.create_pg(["dp"]) + grid.create_pg(["dp", "cp"]) + grid.create_pg(["ep"]) + return grid + + def get_pg_collection(grid): + pg = ProcessGroupCollection() + pg.tp = grid.get_pg("tp") + pg.cp = grid.get_pg("cp") + pg.pp = grid.get_pg("pp") + pg.ep = grid.get_pg("ep") + pg.dp = grid.get_pg("dp") + pg.dp_cp = grid.get_pg(["dp", "cp"]) + + if pg.pp: + pp_ranks = sorted(dist.get_process_group_ranks(pg.pp)) + pos_embd_pg = dist.new_group(ranks=[pp_ranks[0]]) + embd_ranks = [pp_ranks[0]] + if pp_ranks[-1] != pp_ranks[0]: + embd_ranks.append(pp_ranks[-1]) + embd_pg = dist.new_group(ranks=embd_ranks) + pg.pos_embd = pos_embd_pg if is_pp_first_stage(pg.pp) else None + pg.embd = embd_pg if (is_pp_last_stage(pg.pp) or is_pp_first_stage(pg.pp)) else None + + return pg + + # Initialize distributed + dist.init_process_group(backend="nccl") + rank = dist.get_rank() + world_size = dist.get_world_size() + torch.cuda.set_device(rank) + + print(f"[Rank {rank}/{world_size}] Starting MimoOptimizer test") + + # Create grids: encoder on rank 0, LLM on rank 1 + encoder_grid = create_grid(offset=0, tp=1, pp=1, dp=1) + llm_grid = create_grid(offset=1, tp=1, pp=1, dp=1) + + hidden_size = 64 + num_layers = 2 + vocab_size = 1000 + seq_len = 64 + + # Create model specs + encoder_pg = get_pg_collection(encoder_grid) + llm_pg = get_pg_collection(llm_grid) + + lm_config = TransformerConfig( + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=8, + use_cpu_initialization=True, + pipeline_dtype=torch.bfloat16, + bf16=True, + ) + + encoder_config = TransformerConfig( + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=8, + use_cpu_initialization=True, + pipeline_dtype=torch.bfloat16, + bf16=True, + ) + + language_model_spec = ModuleSpec( + module=GPTModel, + params={ + "config": lm_config, + "transformer_layer_spec": get_gpt_layer_with_transformer_engine_spec(), + "vocab_size": vocab_size, + "max_sequence_length": seq_len, + "pre_process": True, + "post_process": True, + "pg_collection": llm_pg, + }, + ) + + encoder_spec = ModuleSpec( + module=TransformerBlock, + params={ + "config": encoder_config, + "spec": get_gpt_layer_with_transformer_engine_spec(), + "pg_collection": encoder_pg, + "pre_process": True, + "post_process": True, + }, + ) + + vision_spec = ModuleSpec( + module=VisionModalitySubmodules, + submodules={"encoders": {"clip": encoder_spec}, "input_projections": []}, + ) + + mimo_config = MimoModelConfig( + language_model_spec=language_model_spec, + modality_submodules_spec={"images": vision_spec}, + special_token_ids={"images": 50257}, + module_to_grid_map={"images": encoder_grid, "language": llm_grid}, + language_module_key="language", + ) + + mimo_model = MimoModel(mimo_config) + mimo_model.to(torch.device("cuda")).to(torch.bfloat16) + + # Wrap with DDP + ddp_config = DistributedDataParallelConfig( + overlap_grad_reduce=True, + bucket_size=10000, + use_distributed_optimizer=True, + ) + + if mimo_model.language_model is not None: + mimo_model.language_model = DistributedDataParallel( + config=mimo_model.language_model.config, + ddp_config=ddp_config, + module=mimo_model.language_model, + pg_collection=llm_pg, + ) + + if "images" in mimo_model.modality_submodules and mimo_model.modality_submodules["images"] is not None: + submodule = mimo_model.modality_submodules["images"] + mimo_model.modality_submodules["images"] = DistributedDataParallel( + config=submodule.encoders['clip'].config, + ddp_config=ddp_config, + module=submodule, + pg_collection=encoder_pg, + ) + + # Create optimizer + opt_config = OptimizerConfig( + optimizer='adam', + lr=1e-4, + weight_decay=0.01, + clip_grad=1.0, + bf16=True, + use_distributed_optimizer=True, + ) + + optimizer = get_mimo_optimizer(mimo_model, opt_config) + + print(f"[Rank {rank}] Created optimizer with {len(optimizer._active_optimizers)} active optimizers") + + # Verify structure + assert "images" in optimizer.module_infos + assert "language" in optimizer.module_infos + + if rank == 0: + assert optimizer.module_infos["images"].is_active is True + assert optimizer.module_infos["language"].is_active is False + else: + assert optimizer.module_infos["images"].is_active is False + assert optimizer.module_infos["language"].is_active is True + + # Test zero_grad and basic operations + optimizer.zero_grad() + + # Test state dict + state = optimizer.state_dict() + assert "images" in state + assert "language" in state + + print(f"[Rank {rank}] MimoOptimizer test PASSED") + + dist.destroy_process_group() + + +if __name__ == "__main__": + run_distributed_test() From 5f0a3091c68ea43f53245a07425d36132667b378 Mon Sep 17 00:00:00 2001 From: ykarnati Date: Tue, 3 Feb 2026 10:44:44 -0800 Subject: [PATCH 21/30] Simplify _get_pg_collection_for_optimizer to only fetch required groups - Rename _get_pg_collection_from_grid to _get_pg_collection_for_optimizer - Remove embedding group creation (not needed by optimizer) - Fix mp group to use ["tp", "pp"] instead of just "tp" - Add missing optimizer groups: tp_ep_pp, expt_dp - Update test to create required process groups --- megatron/core/models/mimo/optimizer.py | 60 +++++++++++-------- .../unit_tests/models/test_mimo_optimizer.py | 4 ++ 2 files changed, 38 insertions(+), 26 deletions(-) diff --git a/megatron/core/models/mimo/optimizer.py b/megatron/core/models/mimo/optimizer.py index c967df0c8e6..2da3795520f 100644 --- a/megatron/core/models/mimo/optimizer.py +++ b/megatron/core/models/mimo/optimizer.py @@ -146,35 +146,44 @@ def reload_model_params(self, state_dict=None): opt.reload_model_params(state_dict) -def _get_pg_collection_from_grid(grid) -> ProcessGroupCollection: - """Create ProcessGroupCollection from HyperCommGrid.""" - import torch.distributed as dist - - from megatron.core.pipeline_parallel.utils import is_pp_first_stage, is_pp_last_stage - +def _get_pg_collection_for_optimizer(grid) -> ProcessGroupCollection: + """Create ProcessGroupCollection from HyperCommGrid for optimizer use. + + Only fetches process groups required by the optimizer. Assumes all groups + are pre-created in the grid via grid.create_pg() - does not create any new groups. + + The following groups must be pre-created in the grid before calling this function: + grid.create_pg(["dp"]) + grid.create_pg(["dp", "cp"]) + grid.create_pg(["tp"]) + grid.create_pg(["tp", "pp"]) + grid.create_pg(["tp", "ep", "pp"]) + grid.create_pg(["dp", "ep"]) + + Args: + grid: HyperCommGrid with pre-created process groups. + + Returns: + ProcessGroupCollection containing optimizer-required groups: + - dp: Data parallel group + - dp_cp: Data parallel with context parallel + - tp: Tensor parallel group + - mp: Model parallel group (tp × pp) + - tp_ep_pp: Expert tensor-model-pipeline group + - expt_dp: Expert data parallel group + """ pg = ProcessGroupCollection() - pg.tp = grid.get_pg("tp") - pg.cp = grid.get_pg("cp") - pg.pp = grid.get_pg("pp") - pg.ep = grid.get_pg("ep") + + # Core groups needed by optimizer pg.dp = grid.get_pg("dp") pg.dp_cp = grid.get_pg(["dp", "cp"]) + pg.tp = grid.get_pg("tp") + pg.mp = grid.get_pg(["tp", "pp"]) - # Embedding groups - if pg.pp: - pp_ranks = sorted(dist.get_process_group_ranks(pg.pp)) - pos_embd_ranks = [pp_ranks[0]] - embd_ranks = [pp_ranks[0]] - if pp_ranks[-1] != pp_ranks[0]: - embd_ranks.append(pp_ranks[-1]) - - pos_embd_pg = dist.new_group(ranks=pos_embd_ranks) - embd_pg = dist.new_group(ranks=embd_ranks) - - pg.pos_embd = pos_embd_pg if is_pp_first_stage(pg.pp) else None - pg.embd = embd_pg if (is_pp_last_stage(pg.pp) or is_pp_first_stage(pg.pp)) else None + # Expert groups + pg.tp_ep_pp = grid.get_pg(["tp", "ep", "pp"]) + pg.expt_dp = grid.get_pg(["dp", "ep"]) - pg.mp = grid.get_pg("tp") return pg @@ -194,7 +203,7 @@ def get_mimo_optimizer( is_active = grid.is_rank_in_grid() optimizer = None - pg_collection = None + pg_collection = _get_pg_collection_for_optimizer(grid) if is_active: if module_name == lang_key: @@ -203,7 +212,6 @@ def get_mimo_optimizer( module = mimo_model.modality_submodules.get(module_name) if module is not None: - pg_collection = _get_pg_collection_from_grid(grid) optimizer = get_megatron_optimizer( config=config, model_chunks=[module], diff --git a/tests/unit_tests/models/test_mimo_optimizer.py b/tests/unit_tests/models/test_mimo_optimizer.py index 2634bacd56a..a1d90202814 100644 --- a/tests/unit_tests/models/test_mimo_optimizer.py +++ b/tests/unit_tests/models/test_mimo_optimizer.py @@ -113,6 +113,10 @@ def create_grid(offset=0, tp=1, pp=1, dp=1): grid.create_pg(["dp"]) grid.create_pg(["dp", "cp"]) grid.create_pg(["ep"]) + # Required by _get_pg_collection_for_optimizer + grid.create_pg(["tp", "pp"]) + grid.create_pg(["tp", "ep", "pp"]) + grid.create_pg(["dp", "ep"]) return grid def get_pg_collection(grid): From 81855a13a943b052f87681ae2097dcc9d9df5b59 Mon Sep 17 00:00:00 2001 From: ykarnati Date: Tue, 3 Feb 2026 13:10:40 -0800 Subject: [PATCH 22/30] Fix MimoOptimizer to fail fast on missing modules - Replace .get() with direct indexing for modality_submodules access - When is_active=True, module MUST exist; using .get() silently hid bugs - Direct indexing raises KeyError immediately if module missing - Add is_current_rank_in_grid() method to HyperCommGrid - Format code for consistency Co-Authored-By: Claude Sonnet 4.5 --- megatron/core/hyper_comm_grid.py | 9 +++++++++ megatron/core/models/mimo/optimizer.py | 24 ++++++------------------ 2 files changed, 15 insertions(+), 18 deletions(-) diff --git a/megatron/core/hyper_comm_grid.py b/megatron/core/hyper_comm_grid.py index 9b5cc6cfaf5..48787d5fa5b 100644 --- a/megatron/core/hyper_comm_grid.py +++ b/megatron/core/hyper_comm_grid.py @@ -262,3 +262,12 @@ def _order_dims(self, dims: Union[str, list[str]]) -> Tuple[list[str], str]: unique_group_key = "-".join(ordered_dims) return ordered_dims, unique_group_key + + def is_current_rank_in_grid(self) -> bool: + """Check if the current rank belongs to this grid. + + Returns: + True if the current rank is within [rank_offset, rank_offset + size). + """ + rank = dist.get_rank() + return self.rank_offset <= rank < self.rank_offset + self.size diff --git a/megatron/core/models/mimo/optimizer.py b/megatron/core/models/mimo/optimizer.py index 2da3795520f..eb82d28ebb5 100644 --- a/megatron/core/models/mimo/optimizer.py +++ b/megatron/core/models/mimo/optimizer.py @@ -31,11 +31,7 @@ class MimoOptimizer(MegatronOptimizer): across all modules via all_reduce MAX. """ - def __init__( - self, - module_infos: Dict[str, ModuleOptimizerInfo], - config: OptimizerConfig, - ): + def __init__(self, module_infos: Dict[str, ModuleOptimizerInfo], config: OptimizerConfig): self.module_infos = module_infos self.config = config self._active_optimizers: List[MegatronOptimizer] = [ @@ -187,10 +183,7 @@ def _get_pg_collection_for_optimizer(grid) -> ProcessGroupCollection: return pg -def get_mimo_optimizer( - mimo_model: "MimoModel", - config: OptimizerConfig, -) -> MimoOptimizer: +def get_mimo_optimizer(mimo_model: "MimoModel", config: OptimizerConfig) -> MimoOptimizer: """Create optimizer for MimoModel with heterogeneous parallelism.""" from megatron.core.optimizer import get_megatron_optimizer @@ -200,7 +193,7 @@ def get_mimo_optimizer( module_infos: Dict[str, ModuleOptimizerInfo] = {} for module_name, grid in grid_map.items(): - is_active = grid.is_rank_in_grid() + is_active = grid.is_current_rank_in_grid() optimizer = None pg_collection = _get_pg_collection_for_optimizer(grid) @@ -209,20 +202,15 @@ def get_mimo_optimizer( if module_name == lang_key: module = mimo_model.language_model else: - module = mimo_model.modality_submodules.get(module_name) + module = mimo_model.modality_submodules[module_name] if module is not None: optimizer = get_megatron_optimizer( - config=config, - model_chunks=[module], - pg_collection=pg_collection, + config=config, model_chunks=[module], pg_collection=pg_collection ) module_infos[module_name] = ModuleOptimizerInfo( - optimizer=optimizer, - grid=grid, - pg_collection=pg_collection, - is_active=is_active, + optimizer=optimizer, grid=grid, pg_collection=pg_collection, is_active=is_active ) return MimoOptimizer(module_infos, config) From 80ce148b780c060122632a3739488a73c01e2614 Mon Sep 17 00:00:00 2001 From: ykarnati Date: Thu, 26 Feb 2026 21:48:50 -0800 Subject: [PATCH 23/30] [MIMO] Configure distributed optimizer process group in optimizer setup --- megatron/core/models/mimo/optimizer.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/megatron/core/models/mimo/optimizer.py b/megatron/core/models/mimo/optimizer.py index eb82d28ebb5..f8ccf58abbf 100644 --- a/megatron/core/models/mimo/optimizer.py +++ b/megatron/core/models/mimo/optimizer.py @@ -180,6 +180,10 @@ def _get_pg_collection_for_optimizer(grid) -> ProcessGroupCollection: pg.tp_ep_pp = grid.get_pg(["tp", "ep", "pp"]) pg.expt_dp = grid.get_pg(["dp", "ep"]) + # Distributed optimizer group (same as dp_cp when num_distributed_optimizer_instances == 1) + # FIXME: Yash - handle multiple optimizer instances + pg.intra_dist_opt = grid.get_pg(["dp", "cp"]) + return pg @@ -206,7 +210,10 @@ def get_mimo_optimizer(mimo_model: "MimoModel", config: OptimizerConfig) -> Mimo if module is not None: optimizer = get_megatron_optimizer( - config=config, model_chunks=[module], pg_collection=pg_collection + config=config, + model_chunks=[module], + pg_collection=pg_collection, + use_gloo_process_groups=False, ) module_infos[module_name] = ModuleOptimizerInfo( From 3e97b3228bdb1204a8dcd8bb2cbdb7b2f4fc0a60 Mon Sep 17 00:00:00 2001 From: ykarnati Date: Thu, 19 Mar 2026 14:39:34 -0700 Subject: [PATCH 24/30] Improve MIMO pipeline abstraction and type safety MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace colocated bool with PipelineMode enum (UNIFIED, NON_COLOCATED, COLOCATED) for clear forward path dispatch - Move _determine_role and _validate_grid_map from MimoModel to RankRole.from_grid_map classmethod — MimoModel no longer knows about grids - Rename LANGUAGE_MODULE_KEY to MIMO_LANGUAGE_MODULE_KEY - Type module_to_grid_map as Dict[str, HyperCommGrid] instead of Dict[str, Any] - Remove torch.distributed import from base.py (moved to role.py) Co-Authored-By: Claude Opus 4.6 (1M context) --- .../core/models/mimo/config/base_configs.py | 11 +- megatron/core/models/mimo/config/role.py | 115 ++++++++++++++-- megatron/core/models/mimo/model/base.py | 127 +++--------------- tests/unit_tests/models/test_mimo_model.py | 16 +-- 4 files changed, 138 insertions(+), 131 deletions(-) diff --git a/megatron/core/models/mimo/config/base_configs.py b/megatron/core/models/mimo/config/base_configs.py index 7fe405b6454..a92484a5a48 100644 --- a/megatron/core/models/mimo/config/base_configs.py +++ b/megatron/core/models/mimo/config/base_configs.py @@ -2,8 +2,9 @@ import warnings from dataclasses import dataclass, field -from typing import Any, Dict, Optional +from typing import Dict, Optional +from megatron.core.hyper_comm_grid import HyperCommGrid from megatron.core.transformer.spec_utils import ModuleSpec @@ -20,10 +21,10 @@ class MimoModelConfig: Dictionary mapping modality names to their special token IDs. For example, {"vision": -200, "audio":32000}, these represent placeholders in the input_ids to insert the modality embeddings at the correct positions. - module_to_grid_map (Optional[Dict[str, Any]]): + module_to_grid_map (Optional[Dict[str, HyperCommGrid]]): Dictionary mapping module keys (e.g., "vision", "language") to their - corresponding grid configurations for non-colocated pipeline parallelism. - The language model must use the key "language" (LANGUAGE_MODULE_KEY). + corresponding HyperCommGrid configurations for non-colocated pipeline + parallelism. The language model must use the key MIMO_LANGUAGE_MODULE_KEY. When None, all modules are assumed to be colocated on the same ranks. kv_format (str): Key-value format for attention: "sbhd" (seq-batch-head-dim) or "thd" (total-head-dim). @@ -40,5 +41,5 @@ class MimoModelConfig: language_model_spec: ModuleSpec = field(default_factory=ModuleSpec) modality_submodules_spec: Dict[str, ModuleSpec] = field(default_factory=dict) special_token_ids: Dict[str, int] = field(default_factory=dict) - module_to_grid_map: Optional[Dict[str, Any]] = None + module_to_grid_map: Optional[Dict[str, HyperCommGrid]] = None kv_format: str = "sbhd" diff --git a/megatron/core/models/mimo/config/role.py b/megatron/core/models/mimo/config/role.py index 408f22d43b5..0b3c11e7034 100644 --- a/megatron/core/models/mimo/config/role.py +++ b/megatron/core/models/mimo/config/role.py @@ -2,12 +2,44 @@ """Data classes for MIMO rank role management in multi-module pipeline parallelism.""" +import logging from dataclasses import dataclass, field +from enum import Enum from typing import Dict, List +import torch.distributed as dist + +from megatron.core.hyper_comm_grid import HyperCommGrid + +logger = logging.getLogger(__name__) + # Fixed key for the language module in module_to_grid_map and RankRole. # MIMO always has exactly one language model, so this is not configurable. -LANGUAGE_MODULE_KEY = "language" +MIMO_LANGUAGE_MODULE_KEY = "language" + + +class PipelineMode(Enum): + """Pipeline mode for MIMO multi-module parallelism. + + Determines how modules are distributed across ranks and which + forward path is used. + + UNIFIED: No module_to_grid_map. All modules share same ranks and + parallelism. Uses the unified forward path (_forward_all_modules). + + NON_COLOCATED: module_to_grid_map is set with non-overlapping rank + ranges. Each rank runs EITHER encoder(s) OR the language model. + Uses role-based dispatch with separate forward paths. + + COLOCATED: (future) module_to_grid_map is set with overlapping rank + ranges. Encoder(s) and language model share ranks but have + different parallelism configs. Uses role-based dispatch but + allows both module types on the same rank. + """ + + UNIFIED = "unified" + NON_COLOCATED = "non_colocated" + COLOCATED = "colocated" @dataclass @@ -29,44 +61,103 @@ class RankRole: This class captures the role of a specific rank in a multi-module pipeline parallel setup, tracking which modules the rank participates in and their - stage positions. The language module is always identified by LANGUAGE_MODULE_KEY. + stage positions. The language module is always identified by MIMO_LANGUAGE_MODULE_KEY. Args: modules: Dict mapping module names to their stage info for modules this rank participates in. - colocated: If True, all modules run on all ranks (no multi-module PP). - The forward path uses the colocated codepath which supports - PartitionAdapter and PackedSeqParams. + mode: Pipeline mode determining forward path dispatch. """ modules: Dict[str, ModuleStageInfo] = field(default_factory=dict) - colocated: bool = False + mode: PipelineMode = PipelineMode.UNIFIED @classmethod - def all_modules(cls, module_names: List[str]) -> 'RankRole': - """Create a role for the colocated case: every module, first+last stage.""" + def unified(cls, module_names: List[str]) -> 'RankRole': + """Create a role for the unified case: every module, first+last stage.""" return cls( modules={ name: ModuleStageInfo(is_first_stage=True, is_last_stage=True) for name in module_names }, - colocated=True, + mode=PipelineMode.UNIFIED, ) + @classmethod + def from_grid_map( + cls, module_to_grid_map: Dict[str, HyperCommGrid], modality_module_names: List[str] + ) -> 'RankRole': + """Create a role from a module-to-grid mapping for non-colocated PP. + + Determines which modules the current rank participates in and its + pipeline stage position within each module. + + Args: + module_to_grid_map: Dict mapping module names to HyperCommGrid objects. + Must contain keys matching modality_module_names + MIMO_LANGUAGE_MODULE_KEY. + modality_module_names: List of modality module names (e.g., ["images", "audio"]). + + Returns: + RankRole for the current rank. + + Raises: + ValueError: If grid map keys don't match expected module names. + RuntimeError: If current rank is not in any module grid. + """ + # Validate keys + expected_keys = set(modality_module_names) | {MIMO_LANGUAGE_MODULE_KEY} + grid_keys = set(module_to_grid_map.keys()) + if grid_keys != expected_keys: + raise ValueError( + f"module_to_grid_map keys must match modality module names + " + f"'{MIMO_LANGUAGE_MODULE_KEY}'. Missing: {expected_keys - grid_keys}, " + f"Extra: {grid_keys - expected_keys}" + ) + + current_rank = dist.get_rank() + modules = {} + + for module_name, grid in module_to_grid_map.items(): + if not (grid.rank_offset <= current_rank < grid.rank_offset + grid.size): + continue + + if "pp" not in grid.dim_names: + modules[module_name] = ModuleStageInfo(is_first_stage=True, is_last_stage=True) + continue + + pp_group = grid.get_pg("pp") + pp_rank = pp_group.rank() + pp_size = pp_group.size() + is_first = pp_rank == 0 + is_last = pp_rank == pp_size - 1 + logger.info( + f"[RankRole.from_grid_map] Rank {current_rank}: module={module_name}, " + f"pp_rank={pp_rank}/{pp_size}, is_first_stage={is_first}, is_last_stage={is_last}" + ) + modules[module_name] = ModuleStageInfo(is_first_stage=is_first, is_last_stage=is_last) + + if not modules: + raise RuntimeError( + f"Rank {current_rank} is not in any module grid. " + f"Check module_to_grid_map configuration." + ) + + return cls(modules=modules, mode=PipelineMode.NON_COLOCATED) + @property def has_modality_modules(self) -> bool: """Return True if this rank participates in any modality (non-language) module.""" - return any(name != LANGUAGE_MODULE_KEY for name in self.modules) + return any(name != MIMO_LANGUAGE_MODULE_KEY for name in self.modules) @property def has_language_module(self) -> bool: """Return True if this rank participates in the language module.""" - return LANGUAGE_MODULE_KEY in self.modules + return MIMO_LANGUAGE_MODULE_KEY in self.modules @property def modality_module_names(self) -> List[str]: """Return names of modality modules (non-language) this rank participates in.""" - return [name for name in self.modules if name != LANGUAGE_MODULE_KEY] + return [name for name in self.modules if name != MIMO_LANGUAGE_MODULE_KEY] def is_first_stage(self, module_name: str) -> bool: """Check if this rank is the first stage for a given module.""" diff --git a/megatron/core/models/mimo/model/base.py b/megatron/core/models/mimo/model/base.py index 115109f4874..fff4c8cdca5 100644 --- a/megatron/core/models/mimo/model/base.py +++ b/megatron/core/models/mimo/model/base.py @@ -5,10 +5,9 @@ from typing import Any, Dict, Optional import torch -import torch.distributed as dist from megatron.core.models.mimo.config import MimoModelConfig -from megatron.core.models.mimo.config.role import LANGUAGE_MODULE_KEY, ModuleStageInfo, RankRole +from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY, PipelineMode, RankRole from megatron.core.models.mimo.partition.utils import PartitionAdapter, PartitionConfig from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.transformer import MegatronModule @@ -57,8 +56,11 @@ def __init__(self, mimo_config: MimoModelConfig, cp_group=None, tp_group=None) - ) self.mimo_config = mimo_config - self._validate_grid_map() - self.role = self._determine_role() + modality_names = list(mimo_config.modality_submodules_spec.keys()) + if mimo_config.module_to_grid_map: + self.role = RankRole.from_grid_map(mimo_config.module_to_grid_map, modality_names) + else: + self.role = RankRole.unified(modality_names + [MIMO_LANGUAGE_MODULE_KEY]) # Use special token IDs from the config self.special_token_ids = ( @@ -67,9 +69,6 @@ def __init__(self, mimo_config: MimoModelConfig, cp_group=None, tp_group=None) - # Extract language model config for partition adapter language_config = mimo_config.language_model_spec.params['config'] - assert ( - language_config.pipeline_model_parallel_size == 1 - ), "Pipeline parallelism is not supported in MimoModel" max_seq_len = mimo_config.language_model_spec.params.get('max_sequence_length', 4096) self.partition_adapter: Optional[PartitionAdapter] = None @@ -160,10 +159,8 @@ def align_embeddings_by_token_positions( def _initialize_submodules(self) -> None: """Initialize modality submodules from the ModuleSpec configurations. - Only modalities present in the config will be instantiated. When role is set, only initializes submodules this rank participates in. - Stage info is passed to from_spec() to conditionally skip projection - initialization on non-last stages (saves memory in pipeline parallelism). + Stage info is passed to from_spec() to conditionally skip projection. """ for modality_name, submodule_spec in self.mimo_config.modality_submodules_spec.items(): if modality_name not in self.role.modules: @@ -202,79 +199,6 @@ def _initialize_language_model(self) -> None: ) self.language_model = build_module(self.mimo_config.language_model_spec) - def _validate_grid_map(self) -> None: - """Validate module_to_grid_map consistency with submodule config. - - Validates that module_to_grid_map keys exactly match - modality_submodules_spec keys + LANGUAGE_MODULE_KEY. - - Raises: - ValueError: If validation fails. - """ - if not self.mimo_config.module_to_grid_map: - return - - grid_map_keys = set(self.mimo_config.module_to_grid_map.keys()) - expected_keys = set(self.mimo_config.modality_submodules_spec.keys()) - expected_keys.add(LANGUAGE_MODULE_KEY) - - if grid_map_keys != expected_keys: - missing_in_grid = expected_keys - grid_map_keys - extra_in_grid = grid_map_keys - expected_keys - raise ValueError( - f"module_to_grid_map keys must match modality_submodules_spec keys + " - f"'{LANGUAGE_MODULE_KEY}'. Missing in grid_map: {missing_in_grid}, " - f"Extra in grid_map: {extra_in_grid}" - ) - - def _determine_role(self) -> RankRole: - """Determine this rank's role based on grid map. - - Returns: - RankRole describing which modules this rank participates in. - For the colocated case (no module_to_grid_map), returns a role with - all modules at first+last stage and colocated=True. - """ - if not self.mimo_config.module_to_grid_map: - # Colocated: all modules on all ranks, single stage - all_module_names = list(self.mimo_config.modality_submodules_spec.keys()) - all_module_names.append(LANGUAGE_MODULE_KEY) - return RankRole.all_modules(all_module_names) - - current_rank = dist.get_rank() - modules = {} - - for module_name, grid in self.mimo_config.module_to_grid_map.items(): - # Check if current rank is in this grid - if not (grid.rank_offset <= current_rank < grid.rank_offset + grid.size): - continue - - # Check if PP dimension exists - if "pp" not in grid.dim_names: - # No PP dimension means single stage (both first and last) - modules[module_name] = ModuleStageInfo(is_first_stage=True, is_last_stage=True) - continue - - # Get PP process group and determine stage - pp_group = grid.get_pg("pp") - pp_rank = pp_group.rank() - pp_size = pp_group.size() - is_first = pp_rank == 0 - is_last = pp_rank == pp_size - 1 - logger.info( - f"[_determine_role] Rank {current_rank}: module={module_name}, " - f"pp_rank={pp_rank}/{pp_size}, is_first_stage={is_first}, is_last_stage={is_last}" - ) - modules[module_name] = ModuleStageInfo(is_first_stage=is_first, is_last_stage=is_last) - - if not modules: - raise RuntimeError( - f"Rank {current_rank} is not in any module grid. " - f"Check module_to_grid_map configuration." - ) - - return RankRole(modules=modules) - def set_input_tensor(self, input_tensor): """Set input tensor for pipeline parallelism. @@ -294,14 +218,11 @@ def set_input_tensor(self, input_tensor): self.input_tensors = input_tensor return - # Backward compatibility: single tensor or list if isinstance(input_tensor, list): input_tensor = input_tensor[0] - # Store as input_tensors for consistency self.input_tensors = input_tensor - # Also delegate to language model for backward compatibility if self.language_model is not None and hasattr(self.language_model, 'set_input_tensor'): self.language_model.set_input_tensor(input_tensor) @@ -391,7 +312,7 @@ def forward( # Get any tensors passed via set_input_tensor input_tensors = getattr(self, 'input_tensors', None) - if self.role.colocated: + if self.role.mode == PipelineMode.UNIFIED: return self._forward_all_modules( input_ids, position_ids, @@ -402,27 +323,21 @@ def forward( packing_kwargs, ) - # Guard: colocated encoders + language module is not supported - if self.role.has_modality_modules and self.role.has_language_module: - raise ValueError( - "Invalid configuration: Colocated encoders and language module on the same " - "rank is not supported in multi-module pipeline parallelism. Use separate " - "grids for encoders and language module, or disable multi-module PP by not " - "setting module_to_grid_map." - ) + if self.role.mode == PipelineMode.NON_COLOCATED: + if self.role.has_modality_modules: + return self._forward_encoders(modality_inputs, input_tensors), loss_mask - if self.role.has_modality_modules: - return self._forward_encoders(modality_inputs, input_tensors), loss_mask + if self.role.has_language_module: + return ( + self._forward_language_module( + input_ids, position_ids, attention_mask, labels, input_tensors + ), + loss_mask, + ) - if self.role.has_language_module: - return ( - self._forward_language_module( - input_ids, position_ids, attention_mask, labels, input_tensors - ), - loss_mask, - ) + raise RuntimeError(f"Rank has no modules assigned in role: {self.role}") - raise RuntimeError(f"Rank has no modules assigned in role: {self.role}") + raise NotImplementedError(f"Pipeline mode {self.role.mode} is not yet supported") def _forward_encoders( self, @@ -475,7 +390,7 @@ def _forward_language_module( Returns: Language model output (hidden states, logits, or loss depending on stage) """ - lang_name = LANGUAGE_MODULE_KEY + lang_name = MIMO_LANGUAGE_MODULE_KEY if self.role.is_first_stage(lang_name): # First stage: receive encoder embeddings, combine with text, pass to LM diff --git a/tests/unit_tests/models/test_mimo_model.py b/tests/unit_tests/models/test_mimo_model.py index 2e2e443e473..e1ed3a5a8a8 100644 --- a/tests/unit_tests/models/test_mimo_model.py +++ b/tests/unit_tests/models/test_mimo_model.py @@ -15,7 +15,7 @@ from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec from megatron.core.models.gpt.gpt_model import GPTModel from megatron.core.models.mimo.config.base_configs import MimoModelConfig -from megatron.core.models.mimo.config.role import LANGUAGE_MODULE_KEY +from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY, PipelineMode from megatron.core.models.mimo.model.base import MimoModel from megatron.core.models.mimo.submodules.audio import AudioModalitySubmodules from megatron.core.models.mimo.submodules.vision import VisionModalitySubmodules @@ -306,8 +306,8 @@ def test_state_dict(self): checkpoint_dict = mimo_model.state_dict_for_save_checkpoint() assert len(checkpoint_dict) > 0 - def test_pipeline_model_parallel_assertion(self): - """Test that MimoModel raises AssertionError when pipeline_model_parallel_size > 1.""" + def test_pipeline_model_parallel_accepted(self): + """Test that MimoModel accepts pipeline_model_parallel_size > 1.""" lm_config_pp2 = TransformerConfig( num_layers=2, hidden_size=self.hidden_size, @@ -333,8 +333,8 @@ def test_pipeline_model_parallel_assertion(self): special_token_ids=self.special_token_ids, ) - with pytest.raises(AssertionError, match="Pipeline parallelism is not supported"): - MimoModel(mimo_config) + model = MimoModel(mimo_config) + assert model is not None def test_partition_adapter_none_by_default(self): """Test that partition_adapter is None with default config (no CP/SP).""" @@ -507,7 +507,7 @@ def _make_config(self, encoder_in_grid=True, language_in_grid=True, pp_rank=0, p pp_rank=pp_rank, pp_size=pp_size, ), - LANGUAGE_MODULE_KEY: MockGrid( + MIMO_LANGUAGE_MODULE_KEY: MockGrid( rank_offset=language_offset, size=1, dim_names=["pp"] if pp_size > 1 else [], @@ -530,7 +530,7 @@ def test_grid_validation_rejects_mismatched_keys(self): language_model_spec=language_model_spec, modality_submodules_spec={"images": vision_submodule_spec}, special_token_ids={"images": 50257}, - module_to_grid_map={LANGUAGE_MODULE_KEY: MockGrid()}, + module_to_grid_map={MIMO_LANGUAGE_MODULE_KEY: MockGrid()}, ) with pytest.raises(ValueError, match="module_to_grid_map keys must match"): @@ -548,7 +548,7 @@ def test_role_determination(self): self.patch_dim, {"images": 50257}, ) - assert model_no_grid.role.colocated is True + assert model_no_grid.role.mode == PipelineMode.UNIFIED assert model_no_grid.role.has_language_module is True assert model_no_grid.role.has_modality_modules is True From 03207bcc584f4a208700f68af2e3ecaf093feef6 Mon Sep 17 00:00:00 2001 From: ykarnati Date: Thu, 19 Mar 2026 15:00:32 -0700 Subject: [PATCH 25/30] Rename PipelineMode to ModuleLayout Better describes the spatial arrangement of modules across ranks without overloading "pipeline" which has specific meaning in Megatron. Co-Authored-By: Claude Opus 4.6 (1M context) --- megatron/core/models/mimo/config/role.py | 8 ++++---- megatron/core/models/mimo/model/base.py | 6 +++--- tests/unit_tests/models/test_mimo_model.py | 4 ++-- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/megatron/core/models/mimo/config/role.py b/megatron/core/models/mimo/config/role.py index 0b3c11e7034..3ca153aa1b0 100644 --- a/megatron/core/models/mimo/config/role.py +++ b/megatron/core/models/mimo/config/role.py @@ -18,7 +18,7 @@ MIMO_LANGUAGE_MODULE_KEY = "language" -class PipelineMode(Enum): +class ModuleLayout(Enum): """Pipeline mode for MIMO multi-module parallelism. Determines how modules are distributed across ranks and which @@ -70,7 +70,7 @@ class RankRole: """ modules: Dict[str, ModuleStageInfo] = field(default_factory=dict) - mode: PipelineMode = PipelineMode.UNIFIED + mode: ModuleLayout = ModuleLayout.UNIFIED @classmethod def unified(cls, module_names: List[str]) -> 'RankRole': @@ -80,7 +80,7 @@ def unified(cls, module_names: List[str]) -> 'RankRole': name: ModuleStageInfo(is_first_stage=True, is_last_stage=True) for name in module_names }, - mode=PipelineMode.UNIFIED, + mode=ModuleLayout.UNIFIED, ) @classmethod @@ -142,7 +142,7 @@ def from_grid_map( f"Check module_to_grid_map configuration." ) - return cls(modules=modules, mode=PipelineMode.NON_COLOCATED) + return cls(modules=modules, mode=ModuleLayout.NON_COLOCATED) @property def has_modality_modules(self) -> bool: diff --git a/megatron/core/models/mimo/model/base.py b/megatron/core/models/mimo/model/base.py index fff4c8cdca5..d38b3d639ad 100644 --- a/megatron/core/models/mimo/model/base.py +++ b/megatron/core/models/mimo/model/base.py @@ -7,7 +7,7 @@ import torch from megatron.core.models.mimo.config import MimoModelConfig -from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY, PipelineMode, RankRole +from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY, ModuleLayout, RankRole from megatron.core.models.mimo.partition.utils import PartitionAdapter, PartitionConfig from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.transformer import MegatronModule @@ -312,7 +312,7 @@ def forward( # Get any tensors passed via set_input_tensor input_tensors = getattr(self, 'input_tensors', None) - if self.role.mode == PipelineMode.UNIFIED: + if self.role.mode == ModuleLayout.UNIFIED: return self._forward_all_modules( input_ids, position_ids, @@ -323,7 +323,7 @@ def forward( packing_kwargs, ) - if self.role.mode == PipelineMode.NON_COLOCATED: + if self.role.mode == ModuleLayout.NON_COLOCATED: if self.role.has_modality_modules: return self._forward_encoders(modality_inputs, input_tensors), loss_mask diff --git a/tests/unit_tests/models/test_mimo_model.py b/tests/unit_tests/models/test_mimo_model.py index e1ed3a5a8a8..1a5f372832c 100644 --- a/tests/unit_tests/models/test_mimo_model.py +++ b/tests/unit_tests/models/test_mimo_model.py @@ -15,7 +15,7 @@ from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec from megatron.core.models.gpt.gpt_model import GPTModel from megatron.core.models.mimo.config.base_configs import MimoModelConfig -from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY, PipelineMode +from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY, ModuleLayout from megatron.core.models.mimo.model.base import MimoModel from megatron.core.models.mimo.submodules.audio import AudioModalitySubmodules from megatron.core.models.mimo.submodules.vision import VisionModalitySubmodules @@ -548,7 +548,7 @@ def test_role_determination(self): self.patch_dim, {"images": 50257}, ) - assert model_no_grid.role.mode == PipelineMode.UNIFIED + assert model_no_grid.role.mode == ModuleLayout.UNIFIED assert model_no_grid.role.has_language_module is True assert model_no_grid.role.has_modality_modules is True From 22fbeba521f296e99c417eab915c171e302e3306 Mon Sep 17 00:00:00 2001 From: ykarnati Date: Fri, 20 Mar 2026 07:43:54 -0700 Subject: [PATCH 26/30] Fix non-colocated forward pass and rewrite MIMO 1F1B schedule tests base.py: - Fix set_input_tensor: unwrap schedule's list wrapper before checking dict type, and unwrap single-element lists in dict values (P2P recv returns [tensor] for VPP compat) - Fix set_input_tensor for DDP: use unwrap_model to call set_input_tensor on underlying GPTModel through DDP wrapper - Remove pipeline_model_parallel_size == 1 assertion (contradicts non-colocated PP goal) test_mimo_1f1b_schedule.py: - Convert from standalone script to pytest class (TestMimo1F1BSchedule) - Add grid tracking + cleanup (destroy_all_grids, teardown_method) - Fix dist.new_group desync: create_all_embedding_groups upfront so all ranks participate in collective new_group calls - Fix embedding groups: set embd=None for encoder ranks (no shared word embeddings to sync in finalize_model_grads) - Fix NVTE env vars: clear conftest's NVTE_FLASH_ATTN=0 before GPTModel creation (LanguageModule asserts these match backend) - Use MIMO_LANGUAGE_MODULE_KEY, 6-dim grid shape with expt_dp - Cache pg_collections to avoid PG leaks in finalize_grads_func - Add BridgeCommunicator.destroy_broadcast_pgs() to teardown Co-Authored-By: Claude Opus 4.6 (1M context) --- megatron/core/models/mimo/model/base.py | 22 +- .../models/test_mimo_1f1b_schedule.py | 662 +++++++++--------- 2 files changed, 328 insertions(+), 356 deletions(-) diff --git a/megatron/core/models/mimo/model/base.py b/megatron/core/models/mimo/model/base.py index d38b3d639ad..49e2fd42116 100644 --- a/megatron/core/models/mimo/model/base.py +++ b/megatron/core/models/mimo/model/base.py @@ -213,14 +213,20 @@ def set_input_tensor(self, input_tensor): Returns: None """ + # The schedule wraps input_tensor in a list (schedules.py:415-416), + # so unwrap first before checking type. + if isinstance(input_tensor, list): + input_tensor = input_tensor[0] + # Store dict input for multi-module PP if isinstance(input_tensor, dict): - self.input_tensors = input_tensor + # P2P recv may return [tensor] (list) for VPP compat — unwrap to tensor + self.input_tensors = { + k: v[0] if isinstance(v, list) and len(v) == 1 else v + for k, v in input_tensor.items() + } return - if isinstance(input_tensor, list): - input_tensor = input_tensor[0] - self.input_tensors = input_tensor if self.language_model is not None and hasattr(self.language_model, 'set_input_tensor'): @@ -425,9 +431,11 @@ def _forward_language_module( # Non-first stage: receive hidden states from previous LM stage hidden_states = input_tensors.get(lang_name) if input_tensors else None - # Set input tensor on language model for PP - if hidden_states is not None and hasattr(self.language_model, 'set_input_tensor'): - self.language_model.set_input_tensor(hidden_states) + # Set input tensor on language model for PP (unwrap DDP to reach GPTModel) + if hidden_states is not None: + underlying_lm = unwrap_model(self.language_model) + if hasattr(underlying_lm, 'set_input_tensor'): + underlying_lm.set_input_tensor(hidden_states) lm_output = self.language_model( input_ids=None, diff --git a/tests/unit_tests/models/test_mimo_1f1b_schedule.py b/tests/unit_tests/models/test_mimo_1f1b_schedule.py index b6cd67b6f54..766899f1404 100644 --- a/tests/unit_tests/models/test_mimo_1f1b_schedule.py +++ b/tests/unit_tests/models/test_mimo_1f1b_schedule.py @@ -3,14 +3,17 @@ """Integration tests for MIMO model with 1F1B pipeline schedule. Run with: - torchrun --nproc_per_node=2 tests/unit_tests/models/test_mimo_1f1b_schedule.py + uv run python -m torch.distributed.run --nproc-per-node=2 -m pytest tests/unit_tests/models/test_mimo_1f1b_schedule.py -v """ import logging -from typing import Dict +from contextlib import ExitStack, contextmanager +from functools import partial +import pytest import torch import torch.distributed as dist +from packaging import version import megatron.core.pipeline_parallel.schedules as schedule from megatron.core.distributed import DistributedDataParallel, DistributedDataParallelConfig @@ -19,18 +22,21 @@ from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec from megatron.core.models.gpt.gpt_model import GPTModel from megatron.core.models.mimo.config.base_configs import MimoModelConfig +from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY from megatron.core.models.mimo.model.base import MimoModel from megatron.core.models.mimo.submodules.vision import VisionModalitySubmodules from megatron.core.models.vision.multimodal_projector import MultimodalProjector -from megatron.core.pipeline_parallel.multimodule_communicator import ( - MultiModulePipelineCommunicator, -) +from megatron.core.pipeline_parallel.bridge_communicator import BridgeCommunicator +from megatron.core.pipeline_parallel.multimodule_communicator import MultiModulePipelineCommunicator from megatron.core.pipeline_parallel.utils import is_pp_first_stage, is_pp_last_stage -from megatron.core.process_groups_config import MultiModuleProcessGroupCollection, ProcessGroupCollection +from megatron.core.process_groups_config import ( + MultiModuleProcessGroupCollection, + ProcessGroupCollection, +) from megatron.core.transformer.mlp import MLP, MLPSubmodules from megatron.core.transformer.spec_utils import ModuleSpec -from megatron.core.transformer.transformer_block import TransformerBlock from megatron.core.transformer.transformer_config import TransformerConfig +from tests.unit_tests.test_utilities import Utils try: from megatron.core.extensions.transformer_engine import ( @@ -41,20 +47,22 @@ TEColumnParallelLinear = None TERowParallelLinear = None -logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # ============================================================================ -# Helper Functions +# Helper Functions (with grid tracking and PG caching from edc8159) # ============================================================================ +_active_grids: list = [] +_embedding_pg_cache: dict = {} + def create_hypercomm_grid(offset=0, tp=1, cp=1, pp=1, dp=1): """Create a HyperCommGrid with specified parallelism.""" grid = HyperCommGrid( - shape=[tp, cp, pp, dp, 1], - dim_names=["tp", "cp", "pp", "dp", "ep"], + shape=[tp, cp, pp, dp, 1, 1], # [tp, cp, pp, dp, ep, expt_dp] + dim_names=["tp", "cp", "pp", "dp", "ep", "expt_dp"], rank_offset=offset, backend="nccl", ) @@ -64,9 +72,20 @@ def create_hypercomm_grid(offset=0, tp=1, cp=1, pp=1, dp=1): grid.create_pg(["dp"]) grid.create_pg(["dp", "cp"]) grid.create_pg(["ep"]) + grid.create_pg(["expt_dp"]) + _active_grids.append(grid) return grid +def destroy_all_grids(): + """Destroy all tracked grids and bridge communicator PGs.""" + for grid in _active_grids: + grid.destroy() + _active_grids.clear() + _embedding_pg_cache.clear() + BridgeCommunicator.destroy_broadcast_pgs() + + def get_pg_collection(grid): """Get ProcessGroupCollection from grid.""" pg_collection = ProcessGroupCollection() @@ -76,36 +95,73 @@ def get_pg_collection(grid): pg_collection.ep = grid.get_pg("ep") pg_collection.dp = grid.get_pg("dp") pg_collection.dp_cp = grid.get_pg(["dp", "cp"]) + pg_collection.expt_dp = grid.get_pg("expt_dp") return pg_collection -def add_embedding_groups(pg_collection): - """Add embedding groups to process group collection.""" +def create_all_embedding_groups(grids): + """Create embedding PGs for all grids upfront. + + dist.new_group is a collective — ALL ranks must call it, even non-members. + We create all embedding groups in a consistent order across all ranks to + avoid hangs from asymmetric new_group calls. + + Args: + grids: List of all HyperCommGrids that need embedding groups. + """ + for grid in grids: + pp_group = grid.get_pg("pp") + if not pp_group: + continue + + pp_ranks = sorted(dist.get_process_group_ranks(pp_group)) + cache_key = tuple(pp_ranks) + + if cache_key not in _embedding_pg_cache: + pos_embd_ranks = [pp_ranks[0]] + embd_ranks = [pp_ranks[0]] + if pp_ranks[-1] != pp_ranks[0]: + embd_ranks.append(pp_ranks[-1]) + _embedding_pg_cache[cache_key] = ( + dist.new_group(ranks=pos_embd_ranks), + dist.new_group(ranks=embd_ranks), + ) + + +def add_embedding_groups(pg_collection, is_language_model=False): + """Add cached embedding groups to a process group collection. + + Must call create_all_embedding_groups() first to ensure PGs exist. + + Args: + pg_collection: ProcessGroupCollection to add embedding groups to. + is_language_model: If True, set embd group for word embedding sync. + """ if not pg_collection.pp: return pg_collection pp_ranks = sorted(dist.get_process_group_ranks(pg_collection.pp)) - pos_embd_ranks = [pp_ranks[0]] - embd_ranks = [pp_ranks[0]] - if pp_ranks[-1] != pp_ranks[0]: - embd_ranks.append(pp_ranks[-1]) - - pos_embd_pg = dist.new_group(ranks=pos_embd_ranks) - embd_pg = dist.new_group(ranks=embd_ranks) + cache_key = tuple(pp_ranks) + pos_embd_pg, embd_pg = _embedding_pg_cache[cache_key] pg_collection.pos_embd = pos_embd_pg if is_pp_first_stage(pg_collection.pp) else None - pg_collection.embd = ( - embd_pg - if (is_pp_last_stage(pg_collection.pp) or is_pp_first_stage(pg_collection.pp)) - else None - ) + + if is_language_model: + pg_collection.embd = ( + embd_pg + if (is_pp_last_stage(pg_collection.pp) or is_pp_first_stage(pg_collection.pp)) + else None + ) + else: + # Encoder submodules have no shared word embeddings to sync + pg_collection.embd = None return pg_collection -def get_pg_collection_with_embedding_groups(grid): - """Get ProcessGroupCollection with embedding groups.""" - return add_embedding_groups(get_pg_collection(grid)) +def get_pg_collection_with_embedding_groups(grid, is_language_model=False): + """Get ProcessGroupCollection with embedding groups (PGs must be pre-created).""" + return add_embedding_groups(get_pg_collection(grid), is_language_model=is_language_model) def is_rank_in_grid(grid): @@ -120,26 +176,12 @@ def is_rank_in_grid(grid): def get_language_model_spec( - num_layers: int, - hidden_size: int, - num_attention_heads: int, - vocab_size: int, - seq_len: int, - pg_collection: ProcessGroupCollection, + num_layers, hidden_size, num_attention_heads, vocab_size, seq_len, pg_collection ): """Get the language model spec.""" pp_rank = dist.get_rank(pg_collection.pp) pp_size = dist.get_world_size(pg_collection.pp) - pre_process = (pp_rank == 0) - post_process = (pp_rank == pp_size - 1) - - logger.info( - f"[get_language_model_spec] Rank {dist.get_rank()}: PP rank={pp_rank}/{pp_size}, " - f"pre_process={pre_process}, post_process={post_process}" - ) - tp_size = pg_collection.tp.size() if pg_collection.tp is not None else 1 - pp_size = pg_collection.pp.size() if pg_collection.pp is not None else 1 lm_config = TransformerConfig( num_layers=num_layers, @@ -155,29 +197,23 @@ def get_language_model_spec( cross_entropy_loss_fusion=True, cross_entropy_fusion_impl='te', ) - language_layer_spec = get_gpt_layer_with_transformer_engine_spec() - language_model_spec = ModuleSpec( + return ModuleSpec( module=GPTModel, params={ "config": lm_config, - "transformer_layer_spec": language_layer_spec, + "transformer_layer_spec": get_gpt_layer_with_transformer_engine_spec(), "vocab_size": vocab_size, "max_sequence_length": seq_len, - "pre_process": pre_process, - "post_process": post_process, + "pre_process": (pp_rank == 0), + "post_process": (pp_rank == pp_size - 1), "pg_collection": pg_collection, }, ) - return language_model_spec -def get_projection_config(hidden_size: int) -> TransformerConfig: +def get_projection_config(hidden_size): """Return a TransformerConfig for the vision projection MLP.""" - cfg = TransformerConfig( - num_layers=1, - hidden_size=hidden_size, - num_attention_heads=1, - ) + cfg = TransformerConfig(num_layers=1, hidden_size=hidden_size, num_attention_heads=1) cfg.ffn_hidden_size = hidden_size cfg.bias_activation_fusion = True cfg.add_bias_linear = True @@ -185,41 +221,25 @@ def get_projection_config(hidden_size: int) -> TransformerConfig: return cfg -def get_projection_layer_spec() -> ModuleSpec: +def get_projection_layer_spec(): """Layer spec for the vision-projection MLP.""" if TEColumnParallelLinear is None or TERowParallelLinear is None: raise RuntimeError("TEColumnParallelLinear and TERowParallelLinear are required") return ModuleSpec( module=MLP, - submodules=MLPSubmodules( - linear_fc1=TEColumnParallelLinear, - linear_fc2=TERowParallelLinear, - ), + submodules=MLPSubmodules(linear_fc1=TEColumnParallelLinear, linear_fc2=TERowParallelLinear), ) def get_vision_submodules_spec( - num_layers: int, - hidden_size: int, - num_attention_heads: int, - language_hidden_size: int, - pg_collection: ProcessGroupCollection, + num_layers, hidden_size, num_attention_heads, language_hidden_size, pg_collection ): """Get the submodule spec for the vision modality.""" - vision_layer_spec = get_gpt_layer_with_transformer_engine_spec() + from megatron.core.transformer.transformer_block import TransformerBlock tp_size = pg_collection.tp.size() if pg_collection.tp is not None else 1 pp_size = pg_collection.pp.size() if pg_collection.pp is not None else 1 - - # Calculate pre/post process based on PP rank (same as language model spec) pp_rank = dist.get_rank(pg_collection.pp) - pre_process = (pp_rank == 0) - post_process = (pp_rank == pp_size - 1) - - logger.info( - f"[get_vision_submodules_spec] Rank {dist.get_rank()}: PP rank={pp_rank}/{pp_size}, " - f"pre_process={pre_process}, post_process={post_process}" - ) vision_config = TransformerConfig( num_layers=num_layers, @@ -237,10 +257,10 @@ def get_vision_submodules_spec( module=TransformerBlock, params={ "config": vision_config, - "spec": vision_layer_spec, + "spec": get_gpt_layer_with_transformer_engine_spec(), "pg_collection": pg_collection, - "pre_process": pre_process, - "post_process": post_process, + "pre_process": (pp_rank == 0), + "post_process": (pp_rank == pp_size - 1), }, ) @@ -255,7 +275,7 @@ def get_vision_submodules_spec( }, ) - vision_submodule_spec = ModuleSpec( + return ModuleSpec( module=VisionModalitySubmodules, submodules={ "encoders": {"clip_encoder": vision_encoder_spec}, @@ -263,93 +283,68 @@ def get_vision_submodules_spec( }, ) - return vision_submodule_spec - def get_mimo_model( - encoder_name: str, - language_module_name: str, - encoder_grid: HyperCommGrid, - llm_grid: HyperCommGrid, - hidden_size: int, - num_layers: int, - vocab_size: int, - seq_len: int, + encoder_name, encoder_grid, llm_grid, hidden_size, num_layers, vocab_size, seq_len ): """Create MIMO model with TransformerBlock encoder and GPTModel LLM.""" - language_pg_collection = get_pg_collection_with_embedding_groups(llm_grid) - vision_pg_collection = get_pg_collection_with_embedding_groups(encoder_grid) + language_pg = get_pg_collection_with_embedding_groups(llm_grid, is_language_model=True) + vision_pg = get_pg_collection_with_embedding_groups(encoder_grid, is_language_model=False) - # Always create full specs on all ranks (POC pattern) language_model_spec = get_language_model_spec( num_layers=num_layers, hidden_size=hidden_size, num_attention_heads=8, vocab_size=vocab_size, seq_len=seq_len, - pg_collection=language_pg_collection, + pg_collection=language_pg, ) - vision_submodule_spec = get_vision_submodules_spec( num_layers=num_layers, hidden_size=hidden_size, num_attention_heads=8, language_hidden_size=hidden_size, - pg_collection=vision_pg_collection, + pg_collection=vision_pg, ) - module_to_grid_map = { - encoder_name: encoder_grid, - language_module_name: llm_grid, - } - topology = { - encoder_name: [language_module_name], - language_module_name: [], - } + module_to_grid_map = {encoder_name: encoder_grid, MIMO_LANGUAGE_MODULE_KEY: llm_grid} + topology = {encoder_name: [MIMO_LANGUAGE_MODULE_KEY], MIMO_LANGUAGE_MODULE_KEY: []} mimo_config = MimoModelConfig( language_model_spec=language_model_spec, modality_submodules_spec={encoder_name: vision_submodule_spec}, special_token_ids={encoder_name: 50257}, module_to_grid_map=module_to_grid_map, - language_module_key=language_module_name, ) - logger.info(f"[Rank {dist.get_rank()}] Creating MimoModel...") mimo_model = MimoModel(mimo_config) - logger.info(f"[Rank {dist.get_rank()}] MimoModel created successfully") - mimo_model.to(torch.device("cuda")).to(torch.bfloat16) # Wrap with DDP ddp_config = DistributedDataParallelConfig( - overlap_grad_reduce=True, - bucket_size=10000, - use_distributed_optimizer=True, + overlap_grad_reduce=True, bucket_size=10000, use_distributed_optimizer=True ) if mimo_model.language_model is not None: - logger.info(f"[Rank {dist.get_rank()}] Wrapping language_model with DDP") mimo_model.language_model = DistributedDataParallel( config=mimo_model.language_model.config, ddp_config=ddp_config, module=mimo_model.language_model, - pg_collection=language_pg_collection, + pg_collection=language_pg, ) if encoder_name in mimo_model.modality_submodules: submodule = mimo_model.modality_submodules[encoder_name] if submodule is not None: - logger.info(f"[Rank {dist.get_rank()}] Wrapping {encoder_name} submodule with DDP") submodule = DistributedDataParallel( config=submodule.encoders['clip_encoder'].config, ddp_config=ddp_config, module=submodule, - pg_collection=vision_pg_collection, + pg_collection=vision_pg, ) mimo_model.modality_submodules[encoder_name] = submodule - return mimo_model, module_to_grid_map, topology + return mimo_model, module_to_grid_map, topology, language_pg, vision_pg # ============================================================================ @@ -358,60 +353,57 @@ def get_mimo_model( class DataIterator: - """Simple data iterator for testing. - - Returns batches matching the POC's MockVLMDataset structure: - - input_ids: [batch_size, seq_length] with image_seq_length image tokens at start - - labels: [batch_size, seq_length] - - loss_mask: [batch_size, seq_length] - - position_ids: [batch_size, seq_length] - - modality_inputs: {modality_name: {encoder_name: {'hidden_states': tensor, 'attention_mask': None}}} - """ - - def __init__(self, hidden_size, seq_length, micro_batch_size, vocab_size, encoder_name, - image_token_id=50257, image_seq_length=None): + """Simple data iterator returning VLM-like batches.""" + + def __init__( + self, + hidden_size, + seq_length, + micro_batch_size, + vocab_size, + encoder_name, + image_token_id=50257, + image_seq_length=None, + ): self.hidden_size = hidden_size self.seq_length = seq_length self.micro_batch_size = micro_batch_size self.vocab_size = vocab_size self.encoder_name = encoder_name self.image_token_id = image_token_id - # Use half the sequence for image tokens by default self.image_seq_length = image_seq_length or (seq_length // 2) def __iter__(self): return self def __next__(self): - # Create encoder input: [image_seq_length, batch_size, hidden_size] - # This matches the number of image tokens in input_ids encoder_hidden_states = torch.randn( - self.image_seq_length, self.micro_batch_size, self.hidden_size, - device='cuda', dtype=torch.bfloat16 + self.image_seq_length, + self.micro_batch_size, + self.hidden_size, + device='cuda', + dtype=torch.bfloat16, ) - # Create input_ids with image tokens at the beginning (like MockVLMDataset) - # Shape: [batch_size, seq_length] image_tokens = torch.full( (self.micro_batch_size, self.image_seq_length), self.image_token_id, - dtype=torch.long, device='cuda' + dtype=torch.long, + device='cuda', ) text_tokens = torch.randint( - 1, self.vocab_size, # Avoid 0 (pad token) + 1, + self.vocab_size, (self.micro_batch_size, self.seq_length - self.image_seq_length), - device='cuda' + device='cuda', ) input_ids = torch.cat([image_tokens, text_tokens], dim=1) - # Create labels (copy of input_ids, with image tokens set to -100) labels = input_ids.clone() labels[input_ids == self.image_token_id] = -100 - # Create loss_mask (0 for image tokens, 1 for text tokens) loss_mask = torch.ones( - self.micro_batch_size, self.seq_length, - device='cuda', dtype=torch.float32 + self.micro_batch_size, self.seq_length, device='cuda', dtype=torch.float32 ) loss_mask[input_ids == self.image_token_id] = 0.0 @@ -419,16 +411,13 @@ def __next__(self): "input_ids": input_ids, "labels": labels, "loss_mask": loss_mask, - "position_ids": torch.arange( - self.seq_length, device='cuda' - ).unsqueeze(0).expand(self.micro_batch_size, -1).clone(), - # modality_inputs structure from POC + "position_ids": torch.arange(self.seq_length, device='cuda') + .unsqueeze(0) + .expand(self.micro_batch_size, -1) + .clone(), "modality_inputs": { self.encoder_name: { - "clip_encoder": { - 'hidden_states': encoder_hidden_states, - 'attention_mask': None, - } + "clip_encoder": {'hidden_states': encoder_hidden_states, 'attention_mask': None} } }, } @@ -440,39 +429,46 @@ def __next__(self): def run_mimo_1f1b_test( - encoder_tp: int, - encoder_pp: int, - encoder_dp: int, - encoder_offset: int, - llm_tp: int, - llm_pp: int, - llm_dp: int, - llm_offset: int, - hidden_size: int = 256, - num_layers: int = 2, - vocab_size: int = 1000, - seq_length: int = 64, - micro_batch_size: int = 2, - num_microbatches: int = 4, + encoder_tp, + encoder_pp, + encoder_dp, + encoder_offset, + llm_tp, + llm_pp, + llm_dp, + llm_offset, + hidden_size=256, + num_layers=2, + vocab_size=1000, + seq_length=64, + micro_batch_size=2, + num_microbatches=4, ): """Run MIMO model through 1F1B schedule and verify.""" + # Clear NVTE env vars that the conftest set_env fixture sets to '0'. + # GPTModel (LanguageModule) asserts these are unset or match the attention backend. + import os + + os.environ.pop('NVTE_FLASH_ATTN', None) + os.environ.pop('NVTE_FUSED_ATTN', None) + os.environ.pop('NVTE_UNFUSED_ATTN', None) + encoder_name = "images" - language_module_name = "language_module" - logger.info(f"[Rank {dist.get_rank()}] Creating grids...") encoder_grid = create_hypercomm_grid( offset=encoder_offset, tp=encoder_tp, cp=1, pp=encoder_pp, dp=encoder_dp ) - llm_grid = create_hypercomm_grid( - offset=llm_offset, tp=llm_tp, cp=1, pp=llm_pp, dp=llm_dp - ) + llm_grid = create_hypercomm_grid(offset=llm_offset, tp=llm_tp, cp=1, pp=llm_pp, dp=llm_dp) + + # Create all embedding PGs upfront — dist.new_group is a collective that + # requires ALL ranks to participate, so we must create them before any + # rank-specific pg_collection calls. + create_all_embedding_groups([encoder_grid, llm_grid]) torch.manual_seed(12345) - logger.info(f"[Rank {dist.get_rank()}] Creating MIMO model...") - mimo_model, module_to_grid_map, topology = get_mimo_model( + mimo_model, module_to_grid_map, topology, language_pg, vision_pg = get_mimo_model( encoder_name=encoder_name, - language_module_name=language_module_name, encoder_grid=encoder_grid, llm_grid=llm_grid, hidden_size=hidden_size, @@ -481,94 +477,73 @@ def run_mimo_1f1b_test( seq_len=seq_length, ) - # Add schedule-related functions to the model's existing config (TransformerConfig) - # Don't replace it with ModelParallelConfig - schedule expects TransformerConfig attributes + # Build schedule functions using pre-created pg_collections (no leaks) + @contextmanager def no_sync_func(): - from contextlib import contextmanager, ExitStack - - @contextmanager - def combined_no_sync(): - with ExitStack() as stack: - if mimo_model.language_model is not None: - stack.enter_context(mimo_model.language_model.no_sync()) - for submodule in mimo_model.modality_submodules.values(): - if submodule is not None: - stack.enter_context(submodule.no_sync()) - yield - - return combined_no_sync() + with ExitStack() as stack: + if mimo_model.language_model is not None: + stack.enter_context(mimo_model.language_model.no_sync()) + for submodule in mimo_model.modality_submodules.values(): + if submodule is not None: + stack.enter_context(submodule.no_sync()) + yield def finalize_grads_func(*args, **kwargs): if mimo_model.language_model is not None: - llm_pg = get_pg_collection_with_embedding_groups(llm_grid) - finalize_model_grads([mimo_model.language_model], num_tokens=None, pg_collection=llm_pg) + finalize_model_grads( + [mimo_model.language_model], num_tokens=None, pg_collection=language_pg + ) for submodule in mimo_model.modality_submodules.values(): if submodule is not None: - encoder_pg = get_pg_collection_with_embedding_groups(encoder_grid) - finalize_model_grads([submodule], num_tokens=None, pg_collection=encoder_pg) + finalize_model_grads([submodule], num_tokens=None, pg_collection=vision_pg) - # Add schedule functions to existing model config mimo_model.config.no_sync_func = no_sync_func mimo_model.config.finalize_model_grads_func = finalize_grads_func mimo_model.config.grad_scale_func = lambda loss: ( torch.tensor(loss, dtype=torch.float32, device='cuda', requires_grad=True) - if isinstance(loss, (int, float)) else loss + if isinstance(loss, (int, float)) + else loss ) - logger.info(f"[Rank {dist.get_rank()}] Creating communicator...") communicator = MultiModulePipelineCommunicator( module_to_grid_map, topology, mimo_model.config, dim_mapping={'s': 0, 'h': 2, 'b': 1} ) - # Create data iterator on: - # - Encoder's first PP stage (needs modality_inputs) - # - LLM's first PP stage (needs input_ids for embeddings) - # - LLM's last PP stage (needs labels for loss) + # Create data iterator on ranks that need it data_iterator = None - - encoder_needs_data = ( - is_rank_in_grid(encoder_grid) and - is_pp_first_stage(encoder_grid.get_pg("pp")) + encoder_needs_data = is_rank_in_grid(encoder_grid) and is_pp_first_stage( + encoder_grid.get_pg("pp") ) - llm_needs_data = ( - is_rank_in_grid(llm_grid) and - (is_pp_first_stage(llm_grid.get_pg("pp")) or is_pp_last_stage(llm_grid.get_pg("pp"))) + llm_needs_data = is_rank_in_grid(llm_grid) and ( + is_pp_first_stage(llm_grid.get_pg("pp")) or is_pp_last_stage(llm_grid.get_pg("pp")) ) - if encoder_needs_data or llm_needs_data: - logger.info(f"[Rank {dist.get_rank()}] Creating data iterator (encoder={encoder_needs_data}, llm={llm_needs_data})") - data_iterator = DataIterator(hidden_size, seq_length, micro_batch_size, vocab_size, encoder_name) + data_iterator = DataIterator( + hidden_size, seq_length, micro_batch_size, vocab_size, encoder_name + ) - # Build MultiModuleProcessGroupCollection - # Only include pg_collections for modules this rank participates in + # Build MultiModuleProcessGroupCollection (reuse pre-created pg_collections) module_pgs = {} + language_model_module_name = None if is_rank_in_grid(encoder_grid): - module_pgs[encoder_name] = get_pg_collection_with_embedding_groups(encoder_grid) + module_pgs[encoder_name] = vision_pg if is_rank_in_grid(llm_grid): - module_pgs[language_module_name] = get_pg_collection_with_embedding_groups(llm_grid) - - # Set language_model_module_name only if this rank participates in LLM - lang_module_name = language_module_name if is_rank_in_grid(llm_grid) else None + module_pgs[MIMO_LANGUAGE_MODULE_KEY] = language_pg + language_model_module_name = MIMO_LANGUAGE_MODULE_KEY pg_collection = MultiModuleProcessGroupCollection( - module_pgs=module_pgs, - language_model_module_name=lang_module_name, + module_pgs=module_pgs, language_model_module_name=language_model_module_name ) def step_func(data_iterator, model): - from functools import partial - def loss_func(loss_mask, output_tensor): - """Loss function matching POC pattern.""" if output_tensor is None: return torch.tensor(0.0, device='cuda', requires_grad=True), {'loss_reduced': 0.0} - # Handle dict output (from encoder or intermediate LLM stages) if isinstance(output_tensor, dict): - if language_module_name in output_tensor: - output = output_tensor[language_module_name] - else: - output = list(output_tensor.values())[0] if output_tensor else None + output = output_tensor.get( + MIMO_LANGUAGE_MODULE_KEY, next(iter(output_tensor.values()), None) + ) else: output = output_tensor @@ -579,12 +554,9 @@ def loss_func(loss_mask, output_tensor): return loss, {'loss_reduced': loss} batch = next(data_iterator) if data_iterator is not None else {'input_ids': None} - # MimoModel.forward() returns (output_tensor, loss_mask) tuple output_tensor, loss_mask = model(**batch) - # Return only output_tensor, bind loss_mask to loss_func via partial return output_tensor, partial(loss_func, loss_mask) - logger.info(f"[Rank {dist.get_rank()}] Running 1F1B schedule with {num_microbatches} microbatches...") losses = schedule.forward_backward_pipelining_without_interleaving( forward_step_func=step_func, data_iterator=data_iterator, @@ -598,130 +570,122 @@ def loss_func(loss_mask, output_tensor): ) # Verify results on last LLM stage - if is_rank_in_grid(llm_grid): - if is_pp_last_stage(llm_grid.get_pg("pp")): - logger.info(f"[Rank {dist.get_rank()}] Last LLM stage - got {len(losses)} losses") - assert len(losses) > 0, "Expected losses on last LLM stage" - for loss_dict in losses: - assert 'loss_reduced' in loss_dict, "Expected 'loss_reduced' in loss dict" + if is_rank_in_grid(llm_grid) and is_pp_last_stage(llm_grid.get_pg("pp")): + assert len(losses) > 0, "Expected losses on last LLM stage" + for loss_dict in losses: + assert 'loss_reduced' in loss_dict - logger.info(f"[Rank {dist.get_rank()}] Test completed successfully!") return losses -def get_test_configs(): - """Get predefined test configurations for different GPU counts. +# ============================================================================ +# Tests +# ============================================================================ + - Returns: - Dict mapping world_size to list of test configurations. - """ - return { - # 2 GPUs: Encoder PP=1, LLM PP=1 (baseline) - 2: [ - { - "name": "baseline_2gpu", - "encoder_tp": 1, "encoder_pp": 1, "encoder_dp": 1, "encoder_offset": 0, - "llm_tp": 1, "llm_pp": 1, "llm_dp": 1, "llm_offset": 1, - "hidden_size": 256, "num_layers": 2, "vocab_size": 1000, - "seq_length": 64, "micro_batch_size": 2, "num_microbatches": 4, - }, - ], - # 4 GPUs: Encoder PP=1, LLM PP=3 (tests keyed output fix) - 4: [ - { - "name": "lm_pp3_4gpu", - "encoder_tp": 1, "encoder_pp": 1, "encoder_dp": 1, "encoder_offset": 0, - "llm_tp": 1, "llm_pp": 3, "llm_dp": 1, "llm_offset": 1, - "hidden_size": 256, "num_layers": 2, "vocab_size": 1000, - "seq_length": 64, "micro_batch_size": 2, "num_microbatches": 4, - }, - ], - # 8 GPUs: Multiple configurations - 8: [ - # Config 1: Encoder TP=2 PP=1, LLM TP=2 PP=3 (heterogeneous) - # Encoder: 2 ranks (0-1), LLM: 6 ranks (2-7) - # num_layers must be divisible by pp, so use 3 - { - "name": "encoder_tp2_llm_tp2_pp3_8gpu", - "encoder_tp": 2, "encoder_pp": 1, "encoder_dp": 1, "encoder_offset": 0, - "llm_tp": 2, "llm_pp": 3, "llm_dp": 1, "llm_offset": 2, - "hidden_size": 256, "num_layers": 3, "vocab_size": 1000, - "seq_length": 64, "micro_batch_size": 2, "num_microbatches": 4, - }, - # Config 2: Encoder PP=2, LLM PP=2 with TP=2 each - # Encoder: 4 ranks (0-3), LLM: 4 ranks (4-7) - { - "name": "full_pp_8gpu", - "encoder_tp": 2, "encoder_pp": 2, "encoder_dp": 1, "encoder_offset": 0, - "llm_tp": 2, "llm_pp": 2, "llm_dp": 1, "llm_offset": 4, - "hidden_size": 256, "num_layers": 2, "vocab_size": 1000, - "seq_length": 64, "micro_batch_size": 2, "num_microbatches": 4, - }, - ], - } - - -def main(): - """Main entry point.""" - import argparse - - parser = argparse.ArgumentParser(description="MIMO 1F1B Schedule Test") - parser.add_argument("--config", type=str, default=None, - help="Specific config name to run (e.g., 'baseline_2gpu')") - parser.add_argument("--list-configs", action="store_true", - help="List available configurations and exit") - args = parser.parse_args() - - # List configs if requested - if args.list_configs: - configs = get_test_configs() - print("Available configurations:") - for world_size, config_list in configs.items(): - print(f"\n {world_size} GPUs:") - for cfg in config_list: - print(f" - {cfg['name']}") - return - - # Initialize distributed - dist.init_process_group(backend="nccl") - rank = dist.get_rank() - world_size = dist.get_world_size() - torch.cuda.set_device(rank) - - logger.info(f"Rank {rank}/{world_size} initialized") - - configs = get_test_configs() - - if world_size not in configs: - logger.error(f"No configurations for world_size={world_size}. Available: {list(configs.keys())}") - dist.destroy_process_group() - return - - # Filter configs if specific one requested - test_configs = configs[world_size] - if args.config: - test_configs = [c for c in test_configs if c["name"] == args.config] - if not test_configs: - logger.error(f"Config '{args.config}' not found for {world_size} GPUs") - dist.destroy_process_group() - return - - # Run all matching configs - for config in test_configs: - name = config.pop("name") - logger.info(f"Running test: {name}") - try: - run_mimo_1f1b_test(**config) - logger.info(f"Test {name} PASSED") - except Exception as e: - logger.error(f"Test {name} FAILED: {e}") - raise - finally: - config["name"] = name # Restore for potential reuse - - dist.destroy_process_group() - logger.info("All tests completed successfully!") - - -if __name__ == "__main__": - main() +@pytest.mark.skipif( + version.parse(torch.__version__) < version.parse('2.3.0'), + reason="Device mesh requires PyTorch 2.3+", +) +class TestMimo1F1BSchedule: + """Test MIMO model with 1F1B pipeline schedule.""" + + @classmethod + def setup_class(cls): + Utils.initialize_distributed() + cls.world_size = dist.get_world_size() + + @classmethod + def teardown_class(cls): + Utils.destroy_model_parallel() + + def teardown_method(self): + destroy_all_grids() + + def test_baseline_2gpu(self): + """Encoder PP=1, LLM PP=1 on 2 GPUs.""" + if self.world_size != 2: + pytest.skip(f"Requires 2 GPUs, got {self.world_size}") + + run_mimo_1f1b_test( + encoder_tp=1, + encoder_pp=1, + encoder_dp=1, + encoder_offset=0, + llm_tp=1, + llm_pp=1, + llm_dp=1, + llm_offset=1, + hidden_size=256, + num_layers=2, + vocab_size=1000, + seq_length=64, + micro_batch_size=2, + num_microbatches=4, + ) + + def test_lm_pp3_4gpu(self): + """Encoder PP=1, LLM PP=3 on 4 GPUs.""" + if self.world_size != 4: + pytest.skip(f"Requires 4 GPUs, got {self.world_size}") + + run_mimo_1f1b_test( + encoder_tp=1, + encoder_pp=1, + encoder_dp=1, + encoder_offset=0, + llm_tp=1, + llm_pp=3, + llm_dp=1, + llm_offset=1, + hidden_size=256, + num_layers=2, + vocab_size=1000, + seq_length=64, + micro_batch_size=2, + num_microbatches=4, + ) + + def test_encoder_tp2_llm_tp2_pp3_8gpu(self): + """Encoder TP=2 PP=1, LLM TP=2 PP=3 on 8 GPUs.""" + if self.world_size != 8: + pytest.skip(f"Requires 8 GPUs, got {self.world_size}") + + run_mimo_1f1b_test( + encoder_tp=2, + encoder_pp=1, + encoder_dp=1, + encoder_offset=0, + llm_tp=2, + llm_pp=3, + llm_dp=1, + llm_offset=2, + hidden_size=256, + num_layers=3, + vocab_size=1000, + seq_length=64, + micro_batch_size=2, + num_microbatches=4, + ) + + def test_full_pp_8gpu(self): + """Encoder PP=2, LLM PP=2 with TP=2 each on 8 GPUs.""" + if self.world_size != 8: + pytest.skip(f"Requires 8 GPUs, got {self.world_size}") + + run_mimo_1f1b_test( + encoder_tp=2, + encoder_pp=2, + encoder_dp=1, + encoder_offset=0, + llm_tp=2, + llm_pp=2, + llm_dp=1, + llm_offset=4, + hidden_size=256, + num_layers=2, + vocab_size=1000, + seq_length=64, + micro_batch_size=2, + num_microbatches=4, + ) From 1335de60b62f4567f6831468ca2fd4dfa291f802 Mon Sep 17 00:00:00 2001 From: Yashaswi Karnati Date: Sun, 22 Mar 2026 22:29:46 +0000 Subject: [PATCH 27/30] Fix stale MiMo tests for multi-rank execution --- tests/unit_tests/models/test_mimo_model.py | 10 ++++++---- tests/unit_tests/models/test_mimo_role.py | 3 --- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/tests/unit_tests/models/test_mimo_model.py b/tests/unit_tests/models/test_mimo_model.py index 1a5f372832c..e1c4b6e89bf 100644 --- a/tests/unit_tests/models/test_mimo_model.py +++ b/tests/unit_tests/models/test_mimo_model.py @@ -9,6 +9,7 @@ import pytest import torch +import torch.distributed as dist import torch.nn as nn from transformers import WhisperConfig, WhisperModel @@ -492,8 +493,9 @@ def _make_config(self, encoder_in_grid=True, language_in_grid=True, pp_rank=0, p self.hidden_size, self.img_h, self.img_w, self.patch_dim ) - encoder_offset = 0 if encoder_in_grid else 10 - language_offset = 0 if language_in_grid else 10 + world_size = dist.get_world_size() + encoder_offset = 0 if encoder_in_grid else world_size + language_offset = 0 if language_in_grid else world_size return MimoModelConfig( language_model_spec=language_model_spec, @@ -502,14 +504,14 @@ def _make_config(self, encoder_in_grid=True, language_in_grid=True, pp_rank=0, p module_to_grid_map={ "images": MockGrid( rank_offset=encoder_offset, - size=1, + size=world_size, dim_names=["pp"] if pp_size > 1 else [], pp_rank=pp_rank, pp_size=pp_size, ), MIMO_LANGUAGE_MODULE_KEY: MockGrid( rank_offset=language_offset, - size=1, + size=world_size, dim_names=["pp"] if pp_size > 1 else [], pp_rank=pp_rank, pp_size=pp_size, diff --git a/tests/unit_tests/models/test_mimo_role.py b/tests/unit_tests/models/test_mimo_role.py index 28f2c5cae54..e67bf4d712f 100644 --- a/tests/unit_tests/models/test_mimo_role.py +++ b/tests/unit_tests/models/test_mimo_role.py @@ -25,7 +25,6 @@ def test_rank_role(self): # Encoder-only role encoder_role = RankRole( modules={"vision": ModuleStageInfo(True, False)}, - language_module_name="language", ) assert encoder_role.has_modality_modules is True assert encoder_role.has_language_module is False @@ -34,7 +33,6 @@ def test_rank_role(self): # Language-only role lang_role = RankRole( modules={"language": ModuleStageInfo(True, True)}, - language_module_name="language", ) assert lang_role.has_modality_modules is False assert lang_role.has_language_module is True @@ -45,7 +43,6 @@ def test_rank_role(self): "vision": ModuleStageInfo(is_first_stage=True, is_last_stage=False), "language": ModuleStageInfo(is_first_stage=False, is_last_stage=True), }, - language_module_name="language", ) assert mixed.is_first_stage("vision") is True assert mixed.is_last_stage("vision") is False From 0ad8abecb948bf088910c31afba11e810e2fb0ff Mon Sep 17 00:00:00 2001 From: Yashaswi Karnati Date: Mon, 23 Mar 2026 00:03:46 +0000 Subject: [PATCH 28/30] Fix post-merge compatibility: use MIMO_LANGUAGE_MODULE_KEY constant After merging PR3211 into PR3212, the optimizer code referenced the removed `config.language_module_key` attribute. Update to use the constant `MIMO_LANGUAGE_MODULE_KEY` from role.py, remove the stale kwarg from test_mimo_optimizer.py, and add composite process groups (tp-pp, tp-ep-pp, dp-ep) required by the optimizer to the 1F1B test grid setup. Co-Authored-By: Claude Opus 4.6 (1M context) --- megatron/core/models/mimo/optimizer.py | 4 +++- tests/unit_tests/models/test_mimo_1f1b_schedule.py | 3 +++ tests/unit_tests/models/test_mimo_optimizer.py | 1 - 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/megatron/core/models/mimo/optimizer.py b/megatron/core/models/mimo/optimizer.py index f8ccf58abbf..0b2c2575d8f 100644 --- a/megatron/core/models/mimo/optimizer.py +++ b/megatron/core/models/mimo/optimizer.py @@ -192,7 +192,9 @@ def get_mimo_optimizer(mimo_model: "MimoModel", config: OptimizerConfig) -> Mimo from megatron.core.optimizer import get_megatron_optimizer grid_map = mimo_model.mimo_config.module_to_grid_map - lang_key = mimo_model.mimo_config.language_module_key + from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY + + lang_key = MIMO_LANGUAGE_MODULE_KEY module_infos: Dict[str, ModuleOptimizerInfo] = {} diff --git a/tests/unit_tests/models/test_mimo_1f1b_schedule.py b/tests/unit_tests/models/test_mimo_1f1b_schedule.py index 59251ac3724..099e28e9395 100644 --- a/tests/unit_tests/models/test_mimo_1f1b_schedule.py +++ b/tests/unit_tests/models/test_mimo_1f1b_schedule.py @@ -75,6 +75,9 @@ def create_hypercomm_grid(offset=0, tp=1, cp=1, pp=1, dp=1): grid.create_pg(["dp", "cp"]) grid.create_pg(["ep"]) grid.create_pg(["expt_dp"]) + grid.create_pg(["tp", "pp"]) + grid.create_pg(["tp", "ep", "pp"]) + grid.create_pg(["dp", "ep"]) _active_grids.append(grid) return grid diff --git a/tests/unit_tests/models/test_mimo_optimizer.py b/tests/unit_tests/models/test_mimo_optimizer.py index a1d90202814..980c9e72925 100644 --- a/tests/unit_tests/models/test_mimo_optimizer.py +++ b/tests/unit_tests/models/test_mimo_optimizer.py @@ -213,7 +213,6 @@ def get_pg_collection(grid): modality_submodules_spec={"images": vision_spec}, special_token_ids={"images": 50257}, module_to_grid_map={"images": encoder_grid, "language": llm_grid}, - language_module_key="language", ) mimo_model = MimoModel(mimo_config) From d3600b336c43bbaf10611e688bd02cf079f353d8 Mon Sep 17 00:00:00 2001 From: Yashaswi Karnati Date: Mon, 23 Mar 2026 00:41:42 +0000 Subject: [PATCH 29/30] Fix numpy.bool_ type issue and invalid 4-GPU test configuration Two fixes for the MIMO non-colocated pipeline tests: 1. HyperCommGrid.is_current_rank_in_grid() returned numpy.bool_ (from np.prod) instead of Python bool, causing `is True` checks to fail in the distributed optimizer test. 2. test_lm_pp3_4gpu used num_layers=2 with llm_pp=3, violating the Megatron constraint that num_layers must be divisible by pp_size. This caused an assertion failure on LLM ranks while the encoder rank waited at a barrier, appearing as a deadlock. Fixed by changing num_layers to 3. Co-Authored-By: Claude Opus 4.6 (1M context) --- megatron/core/hyper_comm_grid.py | 2 +- tests/unit_tests/models/test_mimo_1f1b_schedule.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/megatron/core/hyper_comm_grid.py b/megatron/core/hyper_comm_grid.py index 48787d5fa5b..4b860396c4e 100644 --- a/megatron/core/hyper_comm_grid.py +++ b/megatron/core/hyper_comm_grid.py @@ -270,4 +270,4 @@ def is_current_rank_in_grid(self) -> bool: True if the current rank is within [rank_offset, rank_offset + size). """ rank = dist.get_rank() - return self.rank_offset <= rank < self.rank_offset + self.size + return bool(self.rank_offset <= rank < self.rank_offset + self.size) diff --git a/tests/unit_tests/models/test_mimo_1f1b_schedule.py b/tests/unit_tests/models/test_mimo_1f1b_schedule.py index 099e28e9395..b6fd9b09a1d 100644 --- a/tests/unit_tests/models/test_mimo_1f1b_schedule.py +++ b/tests/unit_tests/models/test_mimo_1f1b_schedule.py @@ -668,7 +668,7 @@ def test_lm_pp3_4gpu(self): llm_dp=1, llm_offset=1, hidden_size=256, - num_layers=2, + num_layers=3, vocab_size=1000, seq_length=64, micro_batch_size=2, From 5db7db39e96f78dfdbeb765231213e561f163daf Mon Sep 17 00:00:00 2001 From: ykarnati Date: Tue, 24 Mar 2026 10:19:54 -0700 Subject: [PATCH 30/30] Fix grad norm computation for MIMO DistributedOptimizer The intra_dist_opt group was set to ["dp", "cp"] which only spans data-parallel ranks. This meant the grad norm all-reduce in get_grad_norm_fp32 missed TP/PP/EP ranks that hold different parameter shards, producing an incomplete norm and incorrect gradient clipping. Changed to ["tp", "cp", "ep", "pp", "dp"] (full module world) to match standard Megatron's intra_distributed_optimizer_instance_group which spans all ranks when num_distributed_optimizer_instances == 1. Also added assertion that num_distributed_optimizer_instances == 1, since the MIMO optimizer does not yet support multiple instances. Co-Authored-By: Claude Opus 4.6 (1M context) --- megatron/core/models/mimo/optimizer.py | 21 ++++++++++-- .../models/test_mimo_1f1b_schedule.py | 13 ++++++-- .../unit_tests/models/test_mimo_optimizer.py | 33 ++++++++----------- 3 files changed, 41 insertions(+), 26 deletions(-) diff --git a/megatron/core/models/mimo/optimizer.py b/megatron/core/models/mimo/optimizer.py index 0b2c2575d8f..0845a970649 100644 --- a/megatron/core/models/mimo/optimizer.py +++ b/megatron/core/models/mimo/optimizer.py @@ -110,6 +110,7 @@ def count_zeros(self) -> int: @property def param_groups(self) -> List[dict]: + """Combined param groups from all active module optimizers.""" groups = [] for opt in self._active_optimizers: groups.extend(opt.param_groups) @@ -155,6 +156,7 @@ def _get_pg_collection_for_optimizer(grid) -> ProcessGroupCollection: grid.create_pg(["tp", "pp"]) grid.create_pg(["tp", "ep", "pp"]) grid.create_pg(["dp", "ep"]) + grid.create_pg(["tp", "cp", "ep", "pp", "dp"]) Args: grid: HyperCommGrid with pre-created process groups. @@ -180,9 +182,12 @@ def _get_pg_collection_for_optimizer(grid) -> ProcessGroupCollection: pg.tp_ep_pp = grid.get_pg(["tp", "ep", "pp"]) pg.expt_dp = grid.get_pg(["dp", "ep"]) - # Distributed optimizer group (same as dp_cp when num_distributed_optimizer_instances == 1) - # FIXME: Yash - handle multiple optimizer instances - pg.intra_dist_opt = grid.get_pg(["dp", "cp"]) + # Distributed optimizer grad stats group: must span all dimensions so grad norm + # and found-inf all-reduces see every unique gradient shard. TP/PP/EP ranks hold + # different parameters, DP ranks hold different optimizer shards after reduce-scatter. + # This mirrors standard Megatron's intra_distributed_optimizer_instance_group which + # spans the full world when num_distributed_optimizer_instances == 1. + pg.intra_dist_opt = grid.get_pg(["tp", "cp", "ep", "pp", "dp"]) return pg @@ -211,6 +216,16 @@ def get_mimo_optimizer(mimo_model: "MimoModel", config: OptimizerConfig) -> Mimo module = mimo_model.modality_submodules[module_name] if module is not None: + assert ( + not hasattr(module, 'ddp_config') + or module.ddp_config is None + or module.ddp_config.num_distributed_optimizer_instances == 1 + ), ( + "MIMO optimizer does not yet support " + "num_distributed_optimizer_instances > 1. " + f"Module '{module_name}' has " + f"{module.ddp_config.num_distributed_optimizer_instances} instances." + ) optimizer = get_megatron_optimizer( config=config, model_chunks=[module], diff --git a/tests/unit_tests/models/test_mimo_1f1b_schedule.py b/tests/unit_tests/models/test_mimo_1f1b_schedule.py index b6fd9b09a1d..ea81161f844 100644 --- a/tests/unit_tests/models/test_mimo_1f1b_schedule.py +++ b/tests/unit_tests/models/test_mimo_1f1b_schedule.py @@ -78,6 +78,7 @@ def create_hypercomm_grid(offset=0, tp=1, cp=1, pp=1, dp=1): grid.create_pg(["tp", "pp"]) grid.create_pg(["tp", "ep", "pp"]) grid.create_pg(["dp", "ep"]) + grid.create_pg(["tp", "cp", "ep", "pp", "dp"]) _active_grids.append(grid) return grid @@ -521,7 +522,9 @@ def finalize_grads_func(*args, **kwargs): use_distributed_optimizer=True, ) optimizer = get_mimo_optimizer(mimo_model, opt_config) - logger.info(f"[Rank {dist.get_rank()}] MimoOptimizer created with {len(optimizer._active_optimizers)} active optimizers") + logger.info( + f"[Rank {dist.get_rank()}] MimoOptimizer created with {len(optimizer._active_optimizers)} active optimizers" + ) logger.info(f"[Rank {dist.get_rank()}] Creating communicator...") communicator = MultiModulePipelineCommunicator( @@ -576,7 +579,9 @@ def loss_func(loss_mask, output_tensor): output_tensor, loss_mask = model(**batch) return output_tensor, partial(loss_func, loss_mask) - logger.info(f"[Rank {dist.get_rank()}] Running 1F1B schedule with {num_microbatches} microbatches...") + logger.info( + f"[Rank {dist.get_rank()}] Running 1F1B schedule with {num_microbatches} microbatches..." + ) # Zero gradients before forward/backward optimizer.zero_grad() @@ -596,7 +601,9 @@ def loss_func(loss_mask, output_tensor): # Optimizer step with global gradient clipping logger.info(f"[Rank {dist.get_rank()}] Running optimizer step...") success, grad_norm, num_zeros = optimizer.step() - logger.info(f"[Rank {dist.get_rank()}] Optimizer step: success={success}, grad_norm={grad_norm}") + logger.info( + f"[Rank {dist.get_rank()}] Optimizer step: success={success}, grad_norm={grad_norm}" + ) # Verify results on last LLM stage if is_rank_in_grid(llm_grid) and is_pp_last_stage(llm_grid.get_pg("pp")): diff --git a/tests/unit_tests/models/test_mimo_optimizer.py b/tests/unit_tests/models/test_mimo_optimizer.py index 980c9e72925..70d683f6606 100644 --- a/tests/unit_tests/models/test_mimo_optimizer.py +++ b/tests/unit_tests/models/test_mimo_optimizer.py @@ -19,21 +19,11 @@ class TestModuleOptimizerInfo: """Tests for ModuleOptimizerInfo dataclass.""" def test_create_active(self): - info = ModuleOptimizerInfo( - optimizer=None, - grid=None, - pg_collection=None, - is_active=True, - ) + info = ModuleOptimizerInfo(optimizer=None, grid=None, pg_collection=None, is_active=True) assert info.is_active is True def test_create_inactive(self): - info = ModuleOptimizerInfo( - optimizer=None, - grid=None, - pg_collection=None, - is_active=False, - ) + info = ModuleOptimizerInfo(optimizer=None, grid=None, pg_collection=None, is_active=False) assert info.is_active is False @@ -69,9 +59,7 @@ def test_state_dict_empty(self): from megatron.core.optimizer.optimizer_config import OptimizerConfig config = OptimizerConfig(optimizer='adam', lr=1e-4) - module_infos = { - "encoder": ModuleOptimizerInfo(None, None, None, is_active=False), - } + module_infos = {"encoder": ModuleOptimizerInfo(None, None, None, is_active=False)} opt = MimoOptimizer(module_infos, config) state = opt.state_dict() @@ -83,6 +71,7 @@ def test_state_dict_empty(self): # Integration tests (require torchrun) # ============================================================================ + def run_distributed_test(): """Run distributed integration test.""" import torch.distributed as dist @@ -117,6 +106,7 @@ def create_grid(offset=0, tp=1, pp=1, dp=1): grid.create_pg(["tp", "pp"]) grid.create_pg(["tp", "ep", "pp"]) grid.create_pg(["dp", "ep"]) + grid.create_pg(["tp", "cp", "ep", "pp", "dp"]) return grid def get_pg_collection(grid): @@ -220,9 +210,7 @@ def get_pg_collection(grid): # Wrap with DDP ddp_config = DistributedDataParallelConfig( - overlap_grad_reduce=True, - bucket_size=10000, - use_distributed_optimizer=True, + overlap_grad_reduce=True, bucket_size=10000, use_distributed_optimizer=True ) if mimo_model.language_model is not None: @@ -233,7 +221,10 @@ def get_pg_collection(grid): pg_collection=llm_pg, ) - if "images" in mimo_model.modality_submodules and mimo_model.modality_submodules["images"] is not None: + if ( + "images" in mimo_model.modality_submodules + and mimo_model.modality_submodules["images"] is not None + ): submodule = mimo_model.modality_submodules["images"] mimo_model.modality_submodules["images"] = DistributedDataParallel( config=submodule.encoders['clip'].config, @@ -254,7 +245,9 @@ def get_pg_collection(grid): optimizer = get_mimo_optimizer(mimo_model, opt_config) - print(f"[Rank {rank}] Created optimizer with {len(optimizer._active_optimizers)} active optimizers") + print( + f"[Rank {rank}] Created optimizer with {len(optimizer._active_optimizers)} active optimizers" + ) # Verify structure assert "images" in optimizer.module_infos