diff --git a/3rdparty/Megatron-LM b/3rdparty/Megatron-LM index 9e28104173..10b3ddd4da 160000 --- a/3rdparty/Megatron-LM +++ b/3rdparty/Megatron-LM @@ -1 +1 @@ -Subproject commit 9e2810417315a7ee93b41d4e234454abd3c16af5 +Subproject commit 10b3ddd4da10532b13e467b89c6cedf590e4c026 diff --git a/src/megatron/bridge/data/loaders.py b/src/megatron/bridge/data/loaders.py index 25e17e2723..68eb5daf5c 100644 --- a/src/megatron/bridge/data/loaders.py +++ b/src/megatron/bridge/data/loaders.py @@ -21,7 +21,7 @@ from torch.utils.data import DataLoader from megatron.bridge.data.samplers import build_pretraining_data_loader -from megatron.bridge.training.config import ConfigContainer, DatasetProvider, GPTDatasetConfig +from megatron.bridge.training.config import ConfigContainer, GPTDatasetConfig from megatron.bridge.training.state import TrainState from megatron.bridge.training.utils.sig_utils import DistributedSignalHandler from megatron.bridge.utils.common_utils import print_rank_0 @@ -178,104 +178,128 @@ def build_train_valid_test_data_loaders( Returns: A tuple (train_dataloader, valid_dataloader, test_dataloader). """ - specialized_builder = _resolve_data_loader_builder(cfg) - if specialized_builder is not None: - train_dataloader, valid_dataloader, test_dataloader = specialized_builder( + # Check for MIMO path + from megatron.bridge.data.mimo.base_provider import MimoDatasetProvider + from megatron.bridge.models.mimo.mimo_provider import MimoModelProvider + + if isinstance(cfg.model, MimoModelProvider): + if not isinstance(cfg.dataset, MimoDatasetProvider): + raise ValueError( + "MIMO models require cfg.dataset to be a MimoDatasetProvider. " + "Use HFMimoDatasetProvider, MockMimoProvider, or a subclass of MimoDatasetProvider." + ) + from megatron.bridge.data.mimo.loaders import build_mimo_data_loaders + + train_samples, valid_samples, test_samples = get_train_valid_test_num_samples(cfg) + train_dataloader, valid_dataloader, test_dataloader = build_mimo_data_loaders( cfg=cfg, train_state=train_state, - build_train_valid_test_datasets_provider=build_train_valid_test_datasets_provider, - dp_group=dp_group, + mimo_provider=cfg.dataset, + train_samples=train_samples, + valid_samples=valid_samples, + test_samples=test_samples, ) - else: - (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None) - - print_rank_0("> building train, validation, and test datasets ...") - # Construct the data pipeline - # Build datasets. - train_ds, valid_ds, test_ds = build_train_valid_test_datasets( - cfg=cfg, build_train_valid_test_datasets_provider=build_train_valid_test_datasets_provider - ) + # Sync train_state flags across all ranks. + # Use all_reduce(MAX) since some ranks may not have loaders in heterogeneous MIMO. 
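+ # (Encoder-only pipeline ranks get (None, None, None) back from build_mimo_data_loaders, so their local flags stay 0.)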
+ do_train = train_dataloader is not None and cfg.train.train_iters > 0 + do_valid = valid_dataloader is not None and cfg.validation.eval_iters > 0 + do_test = test_dataloader is not None and cfg.validation.eval_iters > 0 + flags = torch.tensor([int(do_train), int(do_valid), int(do_test)], dtype=torch.long, device="cuda") + torch.distributed.all_reduce(flags, op=torch.distributed.ReduceOp.MAX) + train_state.do_train = flags[0].item() + train_state.do_valid = flags[1].item() + train_state.do_test = flags[2].item() - exit_signal = cfg.train.exit_signal + return train_dataloader, valid_dataloader, test_dataloader - def worker_init_fn(_): - DistributedSignalHandler(exit_signal).__enter__() + (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None) - maybe_worker_init_fn = worker_init_fn if cfg.train.exit_signal_handler_for_dataloader else None + print_rank_0("> building train, validation, and test datasets ...") - # Resolve DP rank/size from provided data-parallel process group - dp_rank = torch.distributed.get_rank(group=dp_group) - dp_size = torch.distributed.get_world_size(group=dp_group) + # Construct the data pipeline + # Build datasets. + train_ds, valid_ds, test_ds = build_train_valid_test_datasets( + cfg=cfg, build_train_valid_test_datasets_provider=build_train_valid_test_datasets_provider + ) - # Build dataloders. - train_dataloader = build_pretraining_data_loader( - train_ds, - train_state.consumed_train_samples, + exit_signal = cfg.train.exit_signal + + def worker_init_fn(_): + DistributedSignalHandler(exit_signal).__enter__() + + maybe_worker_init_fn = worker_init_fn if cfg.train.exit_signal_handler_for_dataloader else None + + # Resolve DP rank/size from provided data-parallel process group + dp_rank = torch.distributed.get_rank(group=dp_group) + dp_size = torch.distributed.get_world_size(group=dp_group) + + # Build dataloaders.
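+ # The consumed_train_samples / consumed_valid_samples offsets let a resumed run skip samples already served.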
+ train_dataloader = build_pretraining_data_loader( + train_ds, + train_state.consumed_train_samples, + cfg.dataset.dataloader_type, + cfg.train.micro_batch_size, + cfg.dataset.num_workers, + cfg.dataset.data_sharding, + worker_init_fn=maybe_worker_init_fn, + collate_fn=train_ds.collate_fn if hasattr(train_ds, "collate_fn") else None, + pin_memory=cfg.dataset.pin_memory, + persistent_workers=cfg.dataset.persistent_workers, + data_parallel_rank=dp_rank, + data_parallel_size=dp_size, + global_batch_size=cfg.train.global_batch_size, + ) + if cfg.validation.skip_train and cfg.validation.eval_iters > 0: + valid_dataloader = build_pretraining_data_loader( + valid_ds, + 0, cfg.dataset.dataloader_type, cfg.train.micro_batch_size, cfg.dataset.num_workers, cfg.dataset.data_sharding, worker_init_fn=maybe_worker_init_fn, - collate_fn=train_ds.collate_fn if hasattr(train_ds, "collate_fn") else None, + collate_fn=valid_ds.collate_fn if hasattr(valid_ds, "collate_fn") else None, + pin_memory=cfg.dataset.pin_memory, + persistent_workers=cfg.dataset.persistent_workers, + data_parallel_rank=dp_rank, + data_parallel_size=dp_size, + global_batch_size=cfg.train.global_batch_size, + ) + elif cfg.validation.eval_iters > 0: + val_dataloader_type = "cyclic" if isinstance(cfg.dataset, GPTDatasetConfig) else cfg.dataset.dataloader_type + valid_dataloader = build_pretraining_data_loader( + valid_ds, + train_state.consumed_valid_samples, + val_dataloader_type, + cfg.train.micro_batch_size, + cfg.dataset.num_workers, + cfg.dataset.data_sharding, + worker_init_fn=maybe_worker_init_fn, + collate_fn=valid_ds.collate_fn if hasattr(valid_ds, "collate_fn") else None, pin_memory=cfg.dataset.pin_memory, persistent_workers=cfg.dataset.persistent_workers, data_parallel_rank=dp_rank, data_parallel_size=dp_size, global_batch_size=cfg.train.global_batch_size, ) - if cfg.validation.skip_train and cfg.validation.eval_iters > 0: - valid_dataloader = build_pretraining_data_loader( - valid_ds, - 0, - cfg.dataset.dataloader_type, - cfg.train.micro_batch_size, - cfg.dataset.num_workers, - cfg.dataset.data_sharding, - worker_init_fn=maybe_worker_init_fn, - collate_fn=valid_ds.collate_fn if hasattr(valid_ds, "collate_fn") else None, - pin_memory=cfg.dataset.pin_memory, - persistent_workers=cfg.dataset.persistent_workers, - data_parallel_rank=dp_rank, - data_parallel_size=dp_size, - global_batch_size=cfg.train.global_batch_size, - ) - elif cfg.validation.eval_iters > 0: - val_dataloader_type = ( - "cyclic" if isinstance(cfg.dataset, GPTDatasetConfig) else cfg.dataset.dataloader_type - ) - valid_dataloader = build_pretraining_data_loader( - valid_ds, - train_state.consumed_valid_samples, - val_dataloader_type, - cfg.train.micro_batch_size, - cfg.dataset.num_workers, - cfg.dataset.data_sharding, - worker_init_fn=maybe_worker_init_fn, - collate_fn=valid_ds.collate_fn if hasattr(valid_ds, "collate_fn") else None, - pin_memory=cfg.dataset.pin_memory, - persistent_workers=cfg.dataset.persistent_workers, - data_parallel_rank=dp_rank, - data_parallel_size=dp_size, - global_batch_size=cfg.train.global_batch_size, - ) - if cfg.validation.eval_iters > 0: - test_dataloader = build_pretraining_data_loader( - test_ds, - 0, - cfg.dataset.dataloader_type, - cfg.train.micro_batch_size, - cfg.dataset.num_workers, - cfg.dataset.data_sharding, - worker_init_fn=maybe_worker_init_fn, - collate_fn=test_ds.collate_fn if hasattr(test_ds, "collate_fn") else None, - pin_memory=cfg.dataset.pin_memory, - persistent_workers=cfg.dataset.persistent_workers, - 
data_parallel_rank=dp_rank, - data_parallel_size=dp_size, - global_batch_size=cfg.train.global_batch_size, - ) + if cfg.validation.eval_iters > 0: + test_dataloader = build_pretraining_data_loader( + test_ds, + 0, + cfg.dataset.dataloader_type, + cfg.train.micro_batch_size, + cfg.dataset.num_workers, + cfg.dataset.data_sharding, + worker_init_fn=maybe_worker_init_fn, + collate_fn=test_ds.collate_fn if hasattr(test_ds, "collate_fn") else None, + pin_memory=cfg.dataset.pin_memory, + persistent_workers=cfg.dataset.persistent_workers, + data_parallel_rank=dp_rank, + data_parallel_size=dp_size, + global_batch_size=cfg.train.global_batch_size, + ) # Flags to know if we need to do training/validation/testing. do_train = train_dataloader is not None and cfg.train.train_iters > 0 @@ -292,58 +316,6 @@ def worker_init_fn(_): return train_dataloader, valid_dataloader, test_dataloader -def _resolve_data_loader_builder( - cfg: ConfigContainer, -) -> Optional[Callable[..., tuple[Optional[DataLoader], Optional[DataLoader], Optional[DataLoader]]]]: - """Resolve an optional model-specific data loader builder. - - This acts as a lightweight dispatch registry so model-specific loader logic - does not need to be hard-coded inline in the generic data loader path. - """ - from megatron.bridge.models.mimo.mimo_provider import MimoModelProvider - - specialized_builders: dict[ - type, - Callable[..., tuple[Optional[DataLoader], Optional[DataLoader], Optional[DataLoader]]], - ] = { - MimoModelProvider: _build_mimo_train_valid_test_data_loaders, - } - - for model_type, builder in specialized_builders.items(): - if isinstance(cfg.model, model_type): - return builder - return None - - -def _build_mimo_train_valid_test_data_loaders( - cfg: ConfigContainer, - train_state: TrainState, - build_train_valid_test_datasets_provider: Callable, # Unused; kept for common builder signature. - dp_group: torch.distributed.ProcessGroup, # Unused; MIMO determines DP per module. -) -> tuple[Optional[DataLoader], Optional[DataLoader], Optional[DataLoader]]: - """Build train/valid/test loaders for MIMO models via the specialized MIMO path.""" - del build_train_valid_test_datasets_provider, dp_group - - if not isinstance(cfg.dataset, DatasetProvider) or not callable(getattr(cfg.dataset, "get_collate_fn", None)): - raise ValueError( - "MIMO models require cfg.dataset to implement DatasetProvider.build_datasets() " - "and a MIMO-compatible get_collate_fn() method (e.g., HFMimoDatasetProvider " - "or MockMimoProvider)." - ) - - from megatron.bridge.data.mimo.loaders import build_mimo_data_loaders - - train_samples, valid_samples, test_samples = get_train_valid_test_num_samples(cfg) - return build_mimo_data_loaders( - cfg=cfg, - train_state=train_state, - mimo_provider=cfg.dataset, - train_samples=train_samples, - valid_samples=valid_samples, - test_samples=test_samples, - ) - - def build_train_valid_test_data_iterators( cfg: ConfigContainer, train_state: TrainState, diff --git a/src/megatron/bridge/data/mimo/__init__.py b/src/megatron/bridge/data/mimo/__init__.py index 70aeaec909..5408cb0802 100644 --- a/src/megatron/bridge/data/mimo/__init__.py +++ b/src/megatron/bridge/data/mimo/__init__.py @@ -1,7 +1,8 @@ -# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
"""MIMO multi-encoder data loading utilities.""" # Providers +from megatron.bridge.data.mimo.base_provider import MimoDatasetProvider from megatron.bridge.data.mimo.collate import mimo_collate_fn from megatron.bridge.data.mimo.dataset import MimoDataset from megatron.bridge.data.mimo.dp_utils import get_mimo_dp_info @@ -14,7 +15,8 @@ # Core "MimoDataset", "mimo_collate_fn", - # Providers + # Providers (base + implementations) + "MimoDatasetProvider", "HFMimoDatasetProvider", "MockMimoProvider", # Utilities diff --git a/src/megatron/bridge/data/mimo/base_provider.py b/src/megatron/bridge/data/mimo/base_provider.py new file mode 100644 index 0000000000..ad25110d98 --- /dev/null +++ b/src/megatron/bridge/data/mimo/base_provider.py @@ -0,0 +1,63 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +"""Base class for MIMO dataset providers.""" + +from __future__ import annotations + +from abc import abstractmethod +from dataclasses import dataclass +from typing import Callable, Optional, Tuple + +from torch.utils.data import Dataset + +from megatron.bridge.training.config import DatasetBuildContext, DatasetProvider + + +@dataclass(kw_only=True) +class MimoDatasetProvider(DatasetProvider): + """Abstract base class for MIMO dataset providers. + + All MIMO dataset providers must inherit from this class and implement + the required methods. This ensures a consistent interface for MIMO + data loading. + + Required methods: + - build_datasets: Build train/valid/test datasets + - get_collate_fn: Return the collate function for batching + + Example: + >>> class MyMimoProvider(MimoDatasetProvider): + ... def build_datasets(self, context): + ... # Build and return datasets + ... return train_ds, valid_ds, test_ds + ... + ... def get_collate_fn(self): + ... # Return collate function + ... return my_collate_fn + """ + + @abstractmethod + def build_datasets( + self, context: DatasetBuildContext + ) -> Tuple[Optional[Dataset], Optional[Dataset], Optional[Dataset]]: + """Build train, validation, and test datasets. + + Args: + context: Build context with sample counts. + + Returns: + Tuple of (train_dataset, valid_dataset, test_dataset). + Any element can be None if not needed. + """ + ... + + @abstractmethod + def get_collate_fn(self) -> Callable: + """Return the collate function for batching. + + The collate function should handle the modality_inputs dict + and batch them appropriately for the model. + + Returns: + Callable that takes a list of samples and returns a batch dict. + """ + ... 
diff --git a/src/megatron/bridge/data/mimo/collate.py b/src/megatron/bridge/data/mimo/collate.py index 64938fd55b..6d7d93a6a0 100644 --- a/src/megatron/bridge/data/mimo/collate.py +++ b/src/megatron/bridge/data/mimo/collate.py @@ -22,6 +22,7 @@ def mimo_collate_fn( batch: List of examples from MimoDataset, each containing: - input_ids: Token IDs with placeholder tokens - labels: Labels for causal LM training + - loss_mask: Per-token loss mask - attention_mask: Attention mask - position_ids: Position indices - modality_inputs: Dict[str, Dict[str, Any]] with preprocessed inputs @@ -31,6 +32,7 @@ def mimo_collate_fn( Dict containing: - input_ids: (batch, seq) stacked token IDs - labels: (batch, seq) stacked labels + - loss_mask: (batch, seq) stacked per-token loss mask - attention_mask: (batch, seq) attention mask - position_ids: (batch, seq) position indices - modality_inputs: Dict[str, Dict[str, Tensor]] with batched modality tensors @@ -59,6 +61,7 @@ def mimo_collate_fn( labels = torch.stack([item["labels"] for item in batch]) attention_mask = torch.stack([item["attention_mask"] for item in batch]) position_ids = torch.stack([item["position_ids"] for item in batch]) + loss_mask = torch.stack([item["loss_mask"] for item in batch]) # Collate modality inputs modality_inputs: Dict[str, Dict[str, Any]] = {} @@ -110,6 +113,7 @@ def mimo_collate_fn( return { "input_ids": input_ids, "labels": labels, + "loss_mask": loss_mask, "attention_mask": attention_mask, "position_ids": position_ids, "modality_inputs": modality_inputs, diff --git a/src/megatron/bridge/data/mimo/dataset.py b/src/megatron/bridge/data/mimo/dataset.py index 68a2d74186..f1264ed5f9 100644 --- a/src/megatron/bridge/data/mimo/dataset.py +++ b/src/megatron/bridge/data/mimo/dataset.py @@ -117,7 +117,8 @@ def __getitem__(self, idx: int) -> Dict[str, Any]: Returns: Dict containing: - input_ids: Tokenized text with placeholder tokens - - labels: Same as input_ids (for causal LM training) + - labels: Shifted input_ids for next-token prediction (-100 for masked positions) + - loss_mask: Float mask (0.0 for padding/image placeholder targets, 1.0 otherwise) - attention_mask: Attention mask - position_ids: Position indices - modality_inputs: Dict[str, Any] with preprocessed inputs per modality @@ -158,9 +159,29 @@ def __getitem__(self, idx: int) -> Dict[str, Any]: attention_mask = torch.ones_like(input_ids) position_ids = torch.arange(len(input_ids)) + # Shift labels by 1 for next-token prediction: label[i] = input_ids[i+1] + labels = input_ids.clone() + labels[:-1] = input_ids[1:] + labels[-1] = -100 # ignore index for the last position + + # Build loss_mask: no loss on padding or encoder placeholder token positions + pad_token_id = self.tokenizer.pad_token_id or 0 + placeholder_ids = set(self.special_token_ids.values()) + + # loss_mask[i] = 0 when the target (labels[i]) is padding or a placeholder + loss_mask = torch.ones_like(input_ids, dtype=torch.float32) + loss_mask[-1] = 0.0 # last position has no valid target + for pid in placeholder_ids: + loss_mask[labels == pid] = 0.0 + loss_mask[labels == pad_token_id] = 0.0 + + # Also mask labels with -100 so CrossEntropyLoss ignores them + labels[loss_mask == 0.0] = -100 + return { "input_ids": input_ids, - "labels": input_ids.clone(), + "labels": labels, + "loss_mask": loss_mask, "attention_mask": attention_mask, "position_ids": position_ids, "modality_inputs": modality_inputs, diff --git a/src/megatron/bridge/data/mimo/dp_utils.py b/src/megatron/bridge/data/mimo/dp_utils.py index 
27edd347ce..4ff632e781 100644 --- a/src/megatron/bridge/data/mimo/dp_utils.py +++ b/src/megatron/bridge/data/mimo/dp_utils.py @@ -3,29 +3,22 @@ from __future__ import annotations -from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict +from typing import TYPE_CHECKING, Dict, Tuple import torch.distributed as dist +from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY if TYPE_CHECKING: from megatron.core.hyper_comm_grid import HyperCommGrid - -@dataclass(frozen=True) -class MimoDpInfo: - """Data-parallel loader metadata for the current rank in MIMO training.""" - - dp_rank: int - dp_size: int - needs_data: bool - loader_module: str + from megatron.bridge.models.mimo.mimo_config import MimoParallelismConfig def get_mimo_dp_info( + mimo_cfg: "MimoParallelismConfig", grids: Dict[str, "HyperCommGrid"], -) -> MimoDpInfo: +) -> Tuple[int, int, bool, str]: """Get DP rank, size, data-loading responsibility, and loader module for MIMO. Determines which module's DP settings to use for data loading based on @@ -34,10 +27,11 @@ def get_mimo_dp_info( In heterogeneous mode, each rank uses its own module's DP settings. Args: + mimo_cfg: MIMO parallelism configuration. grids: Module name to HyperCommGrid mapping from build_hypercomm_grids(). Returns: - MimoDpInfo with: + Tuple of (dp_rank, dp_size, needs_data, loader_module): - dp_rank: This rank's position in DP group. - dp_size: Size of DP group for data sharding. - needs_data: Whether this rank needs to load data (first/last PP stage). @@ -46,10 +40,10 @@ def get_mimo_dp_info( Example: >>> from megatron.bridge.models.mimo.mimo_builder import build_hypercomm_grids >>> grids = build_hypercomm_grids(mimo_cfg) - >>> dp_info = get_mimo_dp_info(grids) - >>> if dp_info.needs_data: + >>> dp_rank, dp_size, needs_data, loader_module = get_mimo_dp_info(mimo_cfg, grids) + >>> if needs_data: ... # Build data loader with dp_rank and dp_size - ... sampler = DistributedSampler(dataset, num_replicas=dp_info.dp_size, rank=dp_info.dp_rank) + ... 
sampler = DistributedSampler(dataset, num_replicas=dp_size, rank=dp_rank) """ current_rank = dist.get_rank() @@ -64,7 +58,7 @@ def get_mimo_dp_info( if my_grid is None or my_module is None: # Rank doesn't participate in any module - return MimoDpInfo(dp_rank=0, dp_size=1, needs_data=False, loader_module="llm") + return 0, 1, False, MIMO_LANGUAGE_MODULE_KEY dp_rank = my_grid.get_pg(["dp"]).rank() dp_size = my_grid.get_pg(["dp"]).size() @@ -73,14 +67,9 @@ def get_mimo_dp_info( pp_rank = pp_group.rank() pp_size = pp_group.size() - if my_module == "llm": + if my_module == MIMO_LANGUAGE_MODULE_KEY: needs_data = (pp_rank == 0) or (pp_rank == pp_size - 1) else: needs_data = pp_rank == 0 - return MimoDpInfo( - dp_rank=dp_rank, - dp_size=dp_size, - needs_data=needs_data, - loader_module=my_module, - ) + return dp_rank, dp_size, needs_data, my_module diff --git a/src/megatron/bridge/data/mimo/hf_provider.py b/src/megatron/bridge/data/mimo/hf_provider.py index cf60fa96da..095a437560 100644 --- a/src/megatron/bridge/data/mimo/hf_provider.py +++ b/src/megatron/bridge/data/mimo/hf_provider.py @@ -5,20 +5,21 @@ from dataclasses import dataclass, field from functools import partial -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, Union from datasets import load_dataset from torch.utils.data import Dataset from transformers import AutoProcessor, AutoTokenizer +from megatron.bridge.data.mimo.base_provider import MimoDatasetProvider from megatron.bridge.data.mimo.collate import mimo_collate_fn from megatron.bridge.data.mimo.dataset import MimoDataset from megatron.bridge.models.hf_pretrained.utils import is_safe_repo -from megatron.bridge.training.config import DatasetBuildContext, DatasetProvider +from megatron.bridge.training.config import DatasetBuildContext @dataclass(kw_only=True) -class HFMimoDatasetProvider(DatasetProvider): +class HFMimoDatasetProvider(MimoDatasetProvider): """DatasetProvider for MIMO models using HuggingFace datasets. Loads datasets from HuggingFace Hub and applies per-modality processors @@ -72,6 +73,9 @@ class HFMimoDatasetProvider(DatasetProvider): train_split: str = "train" valid_split: str = "validation" test_split: str = "test" + trust_remote_code: Optional[bool] = None + hf_data_files: Optional[Union[str, List[str]]] = None + preprocess_fn: Optional[Callable] = None # Cached processors and tokenizer (loaded once) _processors: Optional[Dict[str, Any]] = field(default=None, repr=False) @@ -123,6 +127,7 @@ def _load_hf_dataset(self, split: str) -> Any: dataset = load_dataset( self.hf_dataset_path, name=self.hf_dataset_name, + data_files=self.hf_data_files, split=split, trust_remote_code=is_safe_repo( trust_remote_code=self.trust_remote_code, @@ -159,6 +164,7 @@ def _build_split_dataset( modality_columns=self.modality_columns, text_column=self.text_column, max_samples=target_samples, + preprocess_fn=self.preprocess_fn, ) def build_datasets( diff --git a/src/megatron/bridge/data/mimo/loaders.py b/src/megatron/bridge/data/mimo/loaders.py index b504a93c7b..9f934c665b 100644 --- a/src/megatron/bridge/data/mimo/loaders.py +++ b/src/megatron/bridge/data/mimo/loaders.py @@ -35,7 +35,7 @@ def build_mimo_data_loaders( Args: cfg: Configuration container with MimoModelProvider as cfg.model. train_state: Current training state. - mimo_provider: MIMO dataset provider (e.g., MockMimoProvider) + mimo_provider: MIMO dataset provider (e.g., MockMimoProvider, a MimoDatasetProvider subclass) with get_collate_fn() method.
train_samples: Number of training samples. valid_samples: Number of validation samples. @@ -70,19 +70,19 @@ def build_mimo_data_loaders( if cfg.model.mimo_parallelism_config is None: raise ValueError("mimo_parallelism_config must be set for MIMO data loading.") + if cfg.model._grids is None: + raise ValueError( + "MimoModelProvider._grids is None. Ensure build_model() is called before building data loaders." + ) + print_rank_0("> building MIMO train, validation, and test datasets ...") - # Reuse cached infrastructure (build once if needed). - infra = cfg.model.get_or_build_infra() - grids = infra.module_to_grid_map - dp_info = get_mimo_dp_info(grids) + # Use cached grids from build_model() + grids = cfg.model._grids - print_rank_0( - f" MIMO DP info: dp_rank={dp_info.dp_rank}, dp_size={dp_info.dp_size}, " - f"needs_data={dp_info.needs_data}, loader_module={dp_info.loader_module}" - ) + dp_rank, dp_size, needs_data, loader_module = get_mimo_dp_info(cfg.model.mimo_parallelism_config, grids) - if not dp_info.needs_data: + if not needs_data: return None, None, None # Build datasets @@ -109,8 +109,8 @@ def _make_loader(dataset, shuffle: bool = True) -> Optional[DataLoader]: return None sampler = torch.utils.data.DistributedSampler( dataset, - num_replicas=dp_info.dp_size, - rank=dp_info.dp_rank, + num_replicas=dp_size, + rank=dp_rank, shuffle=shuffle, ) return DataLoader( @@ -123,8 +123,8 @@ def _make_loader(dataset, shuffle: bool = True) -> Optional[DataLoader]: drop_last=mimo_provider.drop_last, ) - return ( - _make_loader(train_ds, shuffle=True), - _make_loader(valid_ds, shuffle=False), - _make_loader(test_ds, shuffle=False), - ) + train_loader = _make_loader(train_ds, shuffle=True) + valid_loader = _make_loader(valid_ds, shuffle=False) + test_loader = _make_loader(test_ds, shuffle=False) + + return train_loader, valid_loader, test_loader diff --git a/src/megatron/bridge/data/mimo/mock_provider.py b/src/megatron/bridge/data/mimo/mock_provider.py index 693836eb7e..4f364de941 100644 --- a/src/megatron/bridge/data/mimo/mock_provider.py +++ b/src/megatron/bridge/data/mimo/mock_provider.py @@ -16,9 +16,10 @@ import numpy as np from PIL import Image +from megatron.bridge.data.mimo.base_provider import MimoDatasetProvider from megatron.bridge.data.mimo.dataset import MimoDataset from megatron.bridge.models.hf_pretrained.utils import is_safe_repo -from megatron.bridge.training.config import DatasetBuildContext, DatasetProvider +from megatron.bridge.training.config import DatasetBuildContext def _generate_random_image(width: int, height: int, rng: np.random.Generator) -> Image.Image: @@ -35,7 +36,7 @@ def _generate_random_audio(duration_sec: float, sample_rate: int, rng: np.random @dataclass(kw_only=True) -class MockMimoProvider(DatasetProvider): +class MockMimoProvider(MimoDatasetProvider): """DatasetProvider for mock MIMO datasets with synthetic multimodal data. Generates synthetic multimodal inputs (random images, audio, etc.) and uses @@ -113,6 +114,12 @@ def _load_tokenizer(self) -> Any: if self._tokenizer is not None: return self._tokenizer + if not self.tokenizer_path: + raise ValueError( + "tokenizer_path must be set for MockMimoProvider. " + "Provide a valid HuggingFace tokenizer path (e.g., 'gpt2', 'meta-llama/Llama-2-7b-hf')." 
+ ) + from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained( diff --git a/src/megatron/bridge/models/mimo/llava_provider.py b/src/megatron/bridge/models/mimo/llava_provider.py index b4790b3532..c4fcf6d25e 100644 --- a/src/megatron/bridge/models/mimo/llava_provider.py +++ b/src/megatron/bridge/models/mimo/llava_provider.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. """LLaVA-style Vision-Language Model provider.""" from dataclasses import dataclass, field @@ -46,9 +46,6 @@ class LlavaMimoProvider(MimoModelProvider): # Optional custom configs language_config: Optional[TransformerConfig] = None - # Make parent's required field optional (we build it in __post_init__) - language_model_spec: Optional[ModuleSpec] = None - def __post_init__(self): """Build specs after initialization.""" if self.vision_encoder_module is None: diff --git a/src/megatron/bridge/models/mimo/mimo_builder.py b/src/megatron/bridge/models/mimo/mimo_builder.py index e3f9ad3d45..82648f73bb 100644 --- a/src/megatron/bridge/models/mimo/mimo_builder.py +++ b/src/megatron/bridge/models/mimo/mimo_builder.py @@ -1,8 +1,7 @@ # Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. - from __future__ import annotations -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, Optional, Tuple import torch.distributed as dist @@ -15,7 +14,7 @@ def build_hypercomm_grids( mimo_parallelism_config: MimoParallelismConfig, -) -> Dict[str, HyperCommGrid]: +) -> Dict[str, "HyperCommGrid"]: """Create HyperCommGrid objects per module from MIMO parallelism config. Creates grids on ALL ranks (required for consistent collective calls), @@ -47,20 +46,18 @@ def build_hypercomm_grids( # Create all standard process groups for dim in ("tp", "cp", "ep", "pp", "dp"): _ = grid.create_pg([dim]) - # Create dp_cp composite group for gradient reduction _ = grid.create_pg(["dp", "cp"]) + _ = grid.create_pg(["tp", "pp"]) + _ = grid.create_pg(["tp", "ep", "pp"]) + _ = grid.create_pg(["dp", "ep"]) + _ = grid.create_pg(["tp", "cp", "ep", "pp", "dp"]) grids[module_name] = grid return grids -def _default_topology(mimo_parallelism_config: MimoParallelismConfig) -> Dict[str, List[str]]: - """Infer a default multi-encoder -> LLM topology.""" - return {name: ["llm"] for name in mimo_parallelism_config.module_names if name != "llm"} | {"llm": []} - - -def create_embedding_and_position_groups( +def populate_embedding_and_position_groups( pp_group: dist.ProcessGroup, ) -> Tuple[Optional[dist.ProcessGroup], Optional[dist.ProcessGroup]]: """Create embedding-related process groups from PP group ranks. @@ -72,10 +69,6 @@ def create_embedding_and_position_groups( IMPORTANT: This calls dist.new_group which is a collective operation. Must be called on all ranks that could participate. - Note: VPP (virtual_pipeline_model_parallel_size > 1) is not supported. - With VPP, pp_ranks[0]/pp_ranks[-1] do not reliably identify the stages - that own embeddings. The caller is responsible for asserting VPP is disabled. - Args: pp_group: The pipeline parallel process group. @@ -100,13 +93,17 @@ def create_embedding_and_position_groups( return pos_embd_pg, embd_pg -def is_current_rank_in_grid(grid: "HyperCommGrid") -> bool: - """Check if the current rank participates in this grid. 
+def is_pp_first_stage(pp_group: Optional[dist.ProcessGroup]) -> bool: + """Check if current rank is first stage in pipeline.""" + if pp_group is None: + return True + pp_ranks = sorted(dist.get_process_group_ranks(pp_group)) + return dist.get_rank() == pp_ranks[0] - Args: - grid: A HyperCommGrid instance. - Returns: - True if dist.get_rank() is within the grid's rank range. - """ - return grid.rank_offset <= dist.get_rank() < (grid.rank_offset + grid.size) +def is_pp_last_stage(pp_group: Optional[dist.ProcessGroup]) -> bool: + """Check if current rank is last stage in pipeline.""" + if pp_group is None: + return True + pp_ranks = sorted(dist.get_process_group_ranks(pp_group)) + return dist.get_rank() == pp_ranks[-1] diff --git a/src/megatron/bridge/models/mimo/mimo_config.py b/src/megatron/bridge/models/mimo/mimo_config.py index a55977b285..5d895cac55 100644 --- a/src/megatron/bridge/models/mimo/mimo_config.py +++ b/src/megatron/bridge/models/mimo/mimo_config.py @@ -1,11 +1,12 @@ # Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. - from __future__ import annotations import warnings from dataclasses import dataclass, field from typing import Optional +from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY + @dataclass class ModuleParallelismConfig: @@ -63,7 +64,7 @@ class MimoParallelismConfig: Note: Phase 1 only supports heterogeneous deployment where each module can have different parallelism configurations and rank offsets. - The LLM module must be named "llm" in module_parallelisms. + The language module must be named MIMO_LANGUAGE_MODULE_KEY ("language") in module_parallelisms. """ module_parallelisms: dict[str, ModuleParallelismConfig] @@ -97,41 +98,65 @@ def _validate_heterogeneous(self) -> None: if cur_start < prev_end: raise ValueError("rank_offset ranges overlap in heterogeneous deployment.") - # Check for gaps between modules (likely misconfiguration) - # Gaps in the middle are errors; leading gaps (rank_offset > 0) are warnings - if ranges: - min_rank = ranges[0][0] # Already sorted by rank_offset - max_rank = ranges[-1][1] - - # Collect all covered ranks - covered_ranks = set() - for parallelism in self.module_parallelisms.values(): - start = parallelism.rank_offset - end = start + parallelism.total_ranks - covered_ranks.update(range(start, end)) - - # Check for gaps between min and max (error - likely misconfiguration) - expected_middle = set(range(min_rank, max_rank)) - gaps_in_middle = expected_middle - covered_ranks - if gaps_in_middle: - raise ValueError( - f"Ranks {sorted(gaps_in_middle)} are not assigned to any module in heterogeneous " - f"deployment. This creates a gap between modules which is not allowed." - ) + def _validate_parallelism_constraints(self) -> None: + """Validate parallelism constraints for cross-module communication. - # Check for leading gap (ranks 0 to min_rank-1 unused) - warning only - if min_rank > 0: - warnings.warn( - f"Ranks {list(range(min_rank))} (before first module) are not assigned to any " - f"module in heterogeneous deployment. 
These ranks will be idle during training.", - stacklevel=3, + - TP sizes must be powers of 2 + - DP sizes must be pairwise divisible (one divides the other) + """ + + def is_power_of_two(n: int) -> bool: + return n > 0 and (n & (n - 1)) == 0 + + # Validate TP is power of 2 + for name, p in self.module_parallelisms.items(): + tp = p.tensor_model_parallel_size + if not is_power_of_two(tp): + raise ValueError( + f"Module '{name}' has TP={tp}, but TP size must be a power of 2 " + f"(1, 2, 4, 8, ...) for cross-module communication compatibility." ) - def finalize(self, world_size: Optional[int]) -> None: - """Finalize parallelism config: compute data_parallel_size and validate.""" - if "llm" not in self.module_parallelisms: + # Validate DP sizes are pairwise divisible + module_names = list(self.module_parallelisms.keys()) + for i, name1 in enumerate(module_names): + for name2 in module_names[i + 1 :]: + dp1 = self.module_parallelisms[name1].data_parallel_size + dp2 = self.module_parallelisms[name2].data_parallel_size + if dp1 is None or dp2 is None: + continue + if dp1 % dp2 != 0 and dp2 % dp1 != 0: + raise ValueError( + f"DP sizes must be divisible between modules. " + f"Module '{name1}' has DP={dp1}, module '{name2}' has DP={dp2}. " + f"One must divide the other for BridgeCommunicator." + ) + + # Validate encoder DP >= LLM DP for embedding alignment + # Encoder modules produce embeddings consumed by LLM. If encoder DP < LLM DP, + # the same encoder batch would need to align with different LLM batches, which fails. + llm_dp = self.module_parallelisms[MIMO_LANGUAGE_MODULE_KEY].data_parallel_size + if llm_dp is not None: + for name, p in self.module_parallelisms.items(): + if name == MIMO_LANGUAGE_MODULE_KEY: + continue + encoder_dp = p.data_parallel_size + if encoder_dp is not None and encoder_dp < llm_dp: + raise ValueError( + f"Encoder module '{name}' has DP={encoder_dp} < LLM DP={llm_dp}. " + f"Encoder DP must be >= LLM DP for embedding alignment across batches." + ) + + def finalize(self, world_size: int) -> None: + """Finalize parallelism config: compute data_parallel_size and validate. + + Args: + world_size: Total number of ranks in the distributed world. + MIMO requires a distributed environment, so this must always be provided. + """ + if MIMO_LANGUAGE_MODULE_KEY not in self.module_parallelisms: raise ValueError( - f"LLM module 'llm' must be in module_parallelisms. " + f"Language module '{MIMO_LANGUAGE_MODULE_KEY}' must be in module_parallelisms. 
" f"Found modules: {list(self.module_parallelisms.keys())}" ) @@ -140,8 +165,8 @@ def finalize(self, world_size: Optional[int]) -> None: parallelism.finalize(None) self._validate_heterogeneous() + self._validate_parallelism_constraints() - if world_size and world_size > 1: - expected = self.total_world_size - if expected and world_size != expected: - raise ValueError(f"MIMO world size mismatch: expected {expected}, got {world_size}.") + expected = self.total_world_size + if expected and world_size != expected: + raise ValueError(f"MIMO world size mismatch: expected {expected}, got {world_size}.") diff --git a/src/megatron/bridge/models/mimo/mimo_ddp.py b/src/megatron/bridge/models/mimo/mimo_ddp.py index 450051ed4c..ca091abca9 100644 --- a/src/megatron/bridge/models/mimo/mimo_ddp.py +++ b/src/megatron/bridge/models/mimo/mimo_ddp.py @@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Dict, Optional -from megatron.bridge.models.mimo.mimo_builder import is_current_rank_in_grid +from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY if TYPE_CHECKING: @@ -45,11 +45,14 @@ def wrap_mimo_model_distributed( """ from megatron.core.distributed import DistributedDataParallel + # Lazy import to avoid circular dependency (models layer loads before training layer) + from megatron.bridge.training.mimo_parallel_utils import is_current_rank_in_grid + # Wrap language model if present and rank participates if mimo_model.language_model is not None: - llm_grid = grids["llm"] - if is_current_rank_in_grid(llm_grid): - llm_pg = pg_collections.get("llm") + llm_grid = grids.get(MIMO_LANGUAGE_MODULE_KEY) + if llm_grid is not None and is_current_rank_in_grid(llm_grid): + llm_pg = pg_collections.get(MIMO_LANGUAGE_MODULE_KEY) if llm_pg is not None: mimo_model.language_model = DistributedDataParallel( config=mimo_model.language_model.config, @@ -63,7 +66,9 @@ def wrap_mimo_model_distributed( for module_name, submodule in mimo_model.modality_submodules.items(): if submodule is None: continue - module_grid = grids[module_name] + module_grid = grids.get(module_name) + if module_grid is None: + continue if not is_current_rank_in_grid(module_grid): continue diff --git a/src/megatron/bridge/models/mimo/mimo_provider.py b/src/megatron/bridge/models/mimo/mimo_provider.py index e9856005aa..d337498d9f 100644 --- a/src/megatron/bridge/models/mimo/mimo_provider.py +++ b/src/megatron/bridge/models/mimo/mimo_provider.py @@ -21,16 +21,17 @@ from megatron.core.distributed import DistributedDataParallelConfig from megatron.core.models.mimo import MimoModel from megatron.core.models.mimo.config.base_configs import MimoModelConfig -from megatron.core.pipeline_parallel.utils import is_pp_first_stage, is_pp_last_stage +from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.spec_utils import ModuleSpec from megatron.core.utils import get_model_config from megatron.bridge.models.mimo.mimo_builder import ( - _default_topology, build_hypercomm_grids, - create_embedding_and_position_groups, + is_pp_first_stage, + is_pp_last_stage, + populate_embedding_and_position_groups, ) from megatron.bridge.models.mimo.mimo_config import MimoParallelismConfig from megatron.bridge.models.mimo.mimo_ddp import wrap_mimo_model_distributed @@ -60,6 +61,7 @@ class MimoModelInfra: topology: Dict[str, List[str]] pg_collections: Dict[str, 
Optional[ProcessGroupCollection]] participating_modules: List[str] + module_output_ndim: Dict[str, int] = field(default_factory=dict) @dataclass @@ -83,7 +85,7 @@ class MimoModelProvider(ModelProviderMixin[MimoModel]): Example: >>> mimo_parallelism_config = MimoParallelismConfig( ... module_parallelisms={ - ... "llm": ModuleParallelismConfig(tensor_model_parallel_size=8), + ... "language": ModuleParallelismConfig(tensor_model_parallel_size=8), ... "clip_encoder": ModuleParallelismConfig(tensor_model_parallel_size=2), ... } ... ) @@ -99,16 +101,26 @@ class MimoModelProvider(ModelProviderMixin[MimoModel]): >>> infra = provider.build_infra() """ - # Model specs (user provides, like llava_vlm.py example) - language_model_spec: ModuleSpec + # Model specs (user provides, like llava_vlm.py example). + # Optional so subclasses (e.g. LlavaMimoProvider) can build it in __post_init__. + language_model_spec: Optional[ModuleSpec] = None modality_submodules_spec: Dict[str, ModuleSpec] = field(default_factory=dict) special_token_ids: Dict[str, int] = field(default_factory=dict) - # Parallelism config (Bridge's value-add) mimo_parallelism_config: Optional[MimoParallelismConfig] = None - # Cached infrastructure for reuse across model/data setup - _cached_infra: Optional[MimoModelInfra] = field(default=None, repr=False) + # Module data-flow DAG for MultiModulePipelineCommunicator. + # If None, auto-derived as: all modality_submodules → language module (terminal). + # Set explicitly for non-standard topologies (e.g., language → generator). + topology: Optional[Dict[str, List[str]]] = None + + # Output tensor dimensionality per module for bridge communicator routing. + # Vision/audio encoders typically produce 2D [S, H]; language modules produce 3D [S, B, H]. + # If None, auto-derived: language module → 3, all others → 2. 
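+ # Example: {"language": 3, "clip_encoder": 2} (matching the class docstring's config).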
+ module_output_ndim: Optional[Dict[str, int]] = None + + # Cached grids after build_model() - used by data loading + _grids: Optional[Dict[str, "HyperCommGrid"]] = field(default=None, repr=False) # Freezing options freeze_language_model: bool = False @@ -116,76 +128,58 @@ class MimoModelProvider(ModelProviderMixin[MimoModel]): freeze_modality_projections: Dict[str, bool] = field(default_factory=dict) # Fields required by ModelProviderMixin / get_model() - # These have sensible defaults for MIMO fp16: bool = False bf16: bool = True use_cpu_initialization: bool = False init_model_with_meta_device: bool = False - virtual_pipeline_model_parallel_size: Optional[int] = None - - @property - def tensor_model_parallel_size(self) -> int: - """Return LLM's tensor parallel size for compatibility with standard code paths.""" - if self.mimo_parallelism_config is None: - return 1 - llm_parallelism = self.mimo_parallelism_config.get_parallelism("llm") - return llm_parallelism.tensor_model_parallel_size - - @property - def pipeline_model_parallel_size(self) -> int: - """Return LLM's pipeline parallel size for compatibility with standard code paths.""" - if self.mimo_parallelism_config is None: - return 1 - llm_parallelism = self.mimo_parallelism_config.get_parallelism("llm") - return llm_parallelism.pipeline_model_parallel_size - - @property - def context_parallel_size(self) -> int: - """Return LLM's context parallel size for compatibility with standard code paths.""" - if self.mimo_parallelism_config is None: - return 1 - llm_parallelism = self.mimo_parallelism_config.get_parallelism("llm") - return llm_parallelism.context_parallel_size def build_infra(self) -> MimoModelInfra: """Build MIMO parallelism infrastructure. This method builds HyperCommGrids, ProcessGroupCollections, and topology - for MIMO's heterogeneous parallelism. It does not mutate provider state. - Use get_or_build_infra() when cached reuse is desired. + for MIMO's heterogeneous parallelism. It is idempotent; aside from + caching the built grids on self._grids for data loading, it does not mutate provider state. Can be called before or after provide(). Call finalize() first to validate the parallelism configuration. Returns: - MimoModelInfra containing grids, topology, pg_collections, and - the list of modules this rank participates in. + MimoModelInfra containing grids, topology, pg_collections, + and the list of modules this rank participates in. """ if self.mimo_parallelism_config is not None: grids = build_hypercomm_grids(self.mimo_parallelism_config) pg_collections = self._get_pg_collections_from_grids(grids) - topology = _default_topology(self.mimo_parallelism_config) else: - # No parallelism - use global process groups grids = {} pg_collections = {} - topology = {} + + if self.topology is not None: + topology = self.topology + else: + topology = {name: [MIMO_LANGUAGE_MODULE_KEY] for name in self.modality_submodules_spec} | { + MIMO_LANGUAGE_MODULE_KEY: [] + } + + # Cache grids for later use (e.g., data loading) + object.__setattr__(self, "_grids", grids) participating_modules = [name for name, pg in pg_collections.items() if pg is not None] + # Derive module output tensor dimensionality if not explicitly configured.
+ if self.module_output_ndim is not None: + output_ndim = self.module_output_ndim + else: + output_ndim = {name: 3 if name == MIMO_LANGUAGE_MODULE_KEY else 2 for name in grids} + return MimoModelInfra( module_to_grid_map=grids, topology=topology, pg_collections=pg_collections, participating_modules=participating_modules, + module_output_ndim=output_ndim, ) - def get_or_build_infra(self) -> MimoModelInfra: - """Return cached MIMO infrastructure, building it once if needed.""" - if self._cached_infra is None: - object.__setattr__(self, "_cached_infra", self.build_infra()) - return self._cached_infra - def _get_pg_collections_from_grids( self, grids: Dict[str, "HyperCommGrid"], @@ -196,25 +190,16 @@ def _get_pg_collections_from_grids( Returns None for modules this rank doesn't participate in. """ pg_collections: Dict[str, Optional[ProcessGroupCollection]] = {} - current_rank = dist.get_rank() for module_name, grid in grids.items(): - # Check if current rank is in this grid's range - if grid.rank_offset <= current_rank < (grid.rank_offset + grid.size): - pp_group = grid.get_pg(["pp"]) - - assert ( - self.virtual_pipeline_model_parallel_size is None or self.virtual_pipeline_model_parallel_size <= 1 - ), ( - f"VPP (virtual_pipeline_model_parallel_size={self.virtual_pipeline_model_parallel_size}) " - f"is not supported with MIMO embedding groups. pp_ranks[0]/pp_ranks[-1] do not " - f"reliably identify embedding stages under VPP." - ) + pp_group = grid.get_pg(["pp"]) - # Create embedding groups for PP > 1 (collective operation on all PP ranks) - pos_embd_pg, embd_pg = create_embedding_and_position_groups(pp_group) + # dist.new_group() is a collective on the default PG — all ranks must + # call it in the same global order regardless of module membership. + pos_embd_pg, embd_pg = populate_embedding_and_position_groups(pp_group) - # Only assign embedding groups to ranks that should have them + # Only build a full PG collection for ranks that participate in this module. + if grid.is_current_rank_in_grid(): first_stage = is_pp_first_stage(pp_group) last_stage = is_pp_last_stage(pp_group) @@ -225,9 +210,9 @@ def _get_pg_collections_from_grids( cp=grid.get_pg(["cp"]), ep=grid.get_pg(["ep"]), dp_cp=grid.get_pg(["dp", "cp"]), - # Position embeddings only on first PP stage + mp=grid.get_pg(["tp", "pp"]), + tp_ep_pp=grid.get_pg(["tp", "ep", "pp"]), pos_embd=pos_embd_pg if first_stage else None, - # Word embeddings on first and last PP stages (for tied embeddings) embd=embd_pg if (first_stage or last_stage) else None, ) else: @@ -239,12 +224,18 @@ def _inject_pg_collection_into_language_spec( self, spec: ModuleSpec, pg_collection: ProcessGroupCollection, + pre_process: Optional[bool] = None, + post_process: Optional[bool] = None, ) -> ModuleSpec: - """Deep copy language model spec and inject pg_collection into params.""" + """Deep copy language model spec and inject stage-aware params.""" spec = copy.deepcopy(spec) if spec.params is None: spec.params = {} spec.params["pg_collection"] = pg_collection + if pre_process is not None: + spec.params["pre_process"] = pre_process + if post_process is not None: + spec.params["post_process"] = post_process return spec def _inject_pg_collection_into_modality_spec( @@ -298,18 +289,29 @@ def provide( consistent with other providers. This method returns a CPU model. Raises: - ValueError: If this rank doesn't participate in any module - (indicates invalid parallelism configuration). 
+ ValueError: If language_model_spec is not set, or if this rank + doesn't participate in any module. """ + if self.language_model_spec is None: + raise ValueError( + "language_model_spec must be set before calling provide(). " + "Set it directly or use a subclass that populates it in __post_init__." + ) + # Build infrastructure - infra = self.get_or_build_infra() + infra = self.build_infra() # Inject pg_collection into language model spec language_spec = self.language_model_spec if self.mimo_parallelism_config: - llm_pg = infra.pg_collections.get("llm") + llm_pg = infra.pg_collections.get(MIMO_LANGUAGE_MODULE_KEY) if llm_pg is not None: - language_spec = self._inject_pg_collection_into_language_spec(language_spec, llm_pg) + language_spec = self._inject_pg_collection_into_language_spec( + language_spec, + llm_pg, + pre_process=is_pp_first_stage(llm_pg.pp), + post_process=is_pp_last_stage(llm_pg.pp), + ) # Inject pg_collection into modality specs modality_specs: Dict[str, ModuleSpec] = {} @@ -324,6 +326,7 @@ def provide( language_model_spec=language_spec, modality_submodules_spec=modality_specs, special_token_ids=self.special_token_ids, + module_to_grid_map=(infra.module_to_grid_map if self.mimo_parallelism_config is not None else None), ) mimo_model = MimoModel(mimo_model_config) @@ -343,8 +346,8 @@ def provide_distributed_model( use_megatron_fsdp: bool = False, use_torch_fsdp2: bool = False, wrap_with_ddp: bool = True, - data_parallel_random_init: bool = False, - use_cpu_initialization: Optional[bool] = False, + data_parallel_random_init: bool = True, + use_cpu_initialization: Optional[bool] = None, init_model_with_meta_device: Optional[bool] = None, pre_wrap_hook: Optional[ Union[ @@ -353,7 +356,6 @@ def provide_distributed_model( ] ] = None, post_wrap_hook: Optional[Callable[[List[MegatronModule]], List[MegatronModule]]] = None, - mixed_precision_wrapper: Optional[Callable] = None, ) -> List[MegatronModule]: """Build MIMO model with heterogeneous parallelism and DDP wrapping. @@ -370,7 +372,7 @@ def provide_distributed_model( 4. Applies pre-wrap hooks 5. Moves to device 6. Wraps each submodule with DDP using its own pg_collection - 7. Applies mixed precision (Float16Module) + 7. Casts to fp16/bf16 (direct casting, not Float16Module) 8. Applies post-wrap hooks Args: @@ -387,7 +389,6 @@ def provide_distributed_model( init_model_with_meta_device: Initialize model on meta device. pre_wrap_hook: Callable(s) to modify model before wrapping. post_wrap_hook: Callable to modify model after wrapping. - mixed_precision_wrapper: Wrapper for mixed precision (e.g., Float16Module). Returns: List containing the wrapped MimoModel. @@ -396,9 +397,6 @@ def provide_distributed_model( ValueError: If this rank doesn't participate in any module (indicates invalid parallelism configuration). 
""" - # Import here to avoid circular imports - from megatron.core.transformer.module import Float16Module - if wrap_with_ddp and ddp_config is None: raise ValueError("ddp_config is required when wrap_with_ddp is True") @@ -410,8 +408,8 @@ def provide_distributed_model( # Finalize parallelism config self.finalize() - # Build infrastructure once and reuse in provide() - infra = self.get_or_build_infra() + # Build infrastructure + infra = self.build_infra() # Get the model model = self.provide() @@ -427,19 +425,36 @@ def provide_distributed_model( if result is not None: model_list = result + # Resolve initialization settings from provider defaults if not specified + local_use_cpu_init = ( + use_cpu_initialization if use_cpu_initialization is not None else self.use_cpu_initialization + ) + local_init_meta_device = ( + init_model_with_meta_device + if init_model_with_meta_device is not None + else self.init_model_with_meta_device + ) + # Move to device - if not use_cpu_initialization and not init_model_with_meta_device: + if not local_use_cpu_init and not local_init_meta_device: for m in model_list: m.cuda(torch.cuda.current_device()) - # Set variable_seq_lengths=True for multimodule pipeline support (required by PR 3129) + # Set variable_seq_lengths=True for multimodule pipeline support (required by PR 3212) # This must be set before the model is used in the training loop for m in model_list: model_config = get_model_config(m) model_config.variable_seq_lengths = True - # Wrap submodules with DDP (before Float16Module) - # MIMO uses per-submodule DDP for heterogeneous parallelism + # Dtype cast must precede DDP wrapping so hooks bind to final parameters. + use_fp16 = fp16 if fp16 is not None else self.fp16 + use_bf16 = bf16 if bf16 is not None else self.bf16 + if use_fp16: + model_list = [m.half() for m in model_list] + elif use_bf16: + model_list = [m.bfloat16() for m in model_list] + + # Per-submodule DDP for heterogeneous parallelism if wrap_with_ddp and ddp_config is not None and self.mimo_parallelism_config: model_list = [ wrap_mimo_model_distributed( @@ -452,19 +467,6 @@ def provide_distributed_model( for m in model_list ] - # Apply mixed precision wrapper - use_fp16 = fp16 if fp16 is not None else self.fp16 - use_bf16 = bf16 if bf16 is not None else self.bf16 - if (use_fp16 or use_bf16) and mixed_precision_wrapper is not None: - model_config = get_model_config(model_list[0]) - model_list = [mixed_precision_wrapper(model_config, m) for m in model_list] - elif (use_fp16 or use_bf16) and mixed_precision_wrapper is None: - # Use default Float16Module - model_config = get_model_config(model_list[0]) - model_config.fp16 = use_fp16 - model_config.bf16 = use_bf16 - model_list = [Float16Module(model_config, m) for m in model_list] - # Apply post-wrap hooks if final_post_wrap_hook: result = final_post_wrap_hook(model_list) @@ -503,17 +505,16 @@ def initialize_model_parallel( seed_kwargs: Optional[dict] = None, **model_parallel_kwargs, ) -> None: - """MIMO uses its own parallelism via MimoParallelismConfig. - - This method is a no-op for MIMO. Parallelism is set up in build_infra() - using HyperCommGrids, not global mpu state. + """MIMO uses per-module HyperCommGrids, not global MPU state. - Note: - Call finalize() to validate the parallelism configuration, then - build_infra() to create the HyperCommGrids. + Raises NotImplementedError to prevent accidental global MPU initialization, + which would corrupt process groups for heterogeneous parallelism. + Use finalize() + build_infra() instead. 
""" - # MIMO manages its own parallelism via HyperCommGrids - pass + raise NotImplementedError( + "MIMO does not use global model parallelism initialization. " + "Use finalize() to validate config and build_infra() to create HyperCommGrids." + ) def _apply_freezing(self, model: MimoModel) -> None: """Apply freezing based on configuration.""" @@ -549,7 +550,9 @@ def finalize(self) -> None: ranks in the world (validated by MimoParallelismConfig.finalize()). """ if self.mimo_parallelism_config is not None: - world_size = dist.get_world_size() if dist.is_initialized() else None - self.mimo_parallelism_config.finalize(world_size) - # Invalidate cached infra in case parallelism config changed. - object.__setattr__(self, "_cached_infra", None) + if not dist.is_initialized(): + raise RuntimeError( + "MIMO requires torch.distributed to be initialized before finalize(). " + "Call torch.distributed.init_process_group() first." + ) + self.mimo_parallelism_config.finalize(dist.get_world_size()) diff --git a/src/megatron/bridge/training/mimo_parallel_utils.py b/src/megatron/bridge/training/mimo_parallel_utils.py new file mode 100644 index 0000000000..94dd801ef8 --- /dev/null +++ b/src/megatron/bridge/training/mimo_parallel_utils.py @@ -0,0 +1,292 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +"""Multi-module process group utilities for MIMO heterogeneous parallel training. + +This module provides utilities for building process group structures and handling +gradients across modules with different parallelism configurations. + +Key functions: +- unwrap_mimo_model(): Unwrap Float16Module/DDP to get underlying MimoModel +- build_pg_collection_for_schedule(): Build pg_collection compatible with schedule +- multimodule_no_sync(): Context manager for gradient sync during microbatch accumulation +- finalize_model_grads_multimodule(): Finalize gradients for each module +- zero_grad_buffer_for_multimodule(): Reset gradient buffers for all modules +- validate_no_stub_ranks(): Ensure every rank participates in at least one module +- validate_data_loader_contract(): Validate data loading constraints +""" + +from __future__ import annotations + +import logging +from contextlib import contextmanager +from typing import TYPE_CHECKING, Dict, List, Tuple + +import torch.distributed as dist +from megatron.core.distributed.finalize_model_grads import finalize_model_grads as _finalize_model_grads +from megatron.core.models.mimo import MimoModel +from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY + +from megatron.bridge.models.mimo.mimo_provider import MimoModelInfra + + +if TYPE_CHECKING: + from megatron.core.hyper_comm_grid import HyperCommGrid + + +logger = logging.getLogger(__name__) + + +def unwrap_mimo_model(model) -> MimoModel: + """Unwrap Float16Module/DDP wrappers to get the underlying MimoModel. + + When using mixed precision (bf16/fp16), models are wrapped in Float16Module. + This function unwraps the model to access MimoModel-specific attributes + like `role`, `mimo_config`, `language_model`, `modality_submodules`, etc. + + Args: + model: A MimoModel or a wrapped version (Float16Module, DDP). + + Returns: + The underlying MimoModel instance. + + Raises: + RuntimeError: If the model cannot be unwrapped to a MimoModel. 
+ """ + unwrapped = model + while not isinstance(unwrapped, MimoModel) and hasattr(unwrapped, "module"): + unwrapped = unwrapped.module + if not isinstance(unwrapped, MimoModel): + raise RuntimeError(f"Failed to unwrap model to MimoModel, got {type(unwrapped)}") + return unwrapped + + +def is_current_rank_in_grid(grid: "HyperCommGrid") -> bool: + """Check if current rank participates in the given grid. + + Args: + grid: HyperCommGrid to check participation in. + + Returns: + True if current rank is within the grid's rank range. + """ + current_rank = dist.get_rank() + return grid.rank_offset <= current_rank < (grid.rank_offset + grid.size) + + +def get_module_to_grid_tuple( + mimo_model: MimoModel, + infra: MimoModelInfra, +) -> List[Tuple]: + """Build list of (module, grid) tuples for all modules the current rank participates in. + + Args: + mimo_model: The MimoModel instance. + infra: MimoModelInfra containing module_to_grid_map. + + Returns: + List of (module, grid) tuples for modules this rank participates in. + """ + module_to_grid_tuple = [] + + # Unwrap Float16Module/DDP if present (used in mixed precision training) + unwrapped_model = unwrap_mimo_model(mimo_model) + + for module_name, grid in infra.module_to_grid_map.items(): + if not is_current_rank_in_grid(grid): + continue + + # Get the actual module from the unwrapped model + if module_name == MIMO_LANGUAGE_MODULE_KEY: + module = unwrapped_model.language_model + elif hasattr(unwrapped_model, "modality_submodules") and module_name in unwrapped_model.modality_submodules: + module = unwrapped_model.modality_submodules[module_name] + else: + logger.warning(f"Module {module_name} not found in MimoModel, skipping") + continue + + module_to_grid_tuple.append((module, grid)) + + return module_to_grid_tuple + + +def build_pg_collection_for_schedule(infra: MimoModelInfra): + """Build pg_collection compatible with schedule. + + Primary: Use MultiModuleProcessGroupCollection if PR 3212 allows + missing LLM PG on encoder-only ranks. + Fallback: Return list of ProcessGroupCollections for participating modules. + + IMPORTANT: Uses infra.pg_collections directly. Do NOT rebuild PGs. + + Args: + infra: MimoModelInfra with pg_collections for each module. + + Returns: + MultiModuleProcessGroupCollection or list of ProcessGroupCollections. + """ + try: + from megatron.core.process_groups_config import MultiModuleProcessGroupCollection + + module_pgs = {k: v for k, v in infra.pg_collections.items() if v is not None} + if not module_pgs: + raise ValueError("module_pgs dict cannot be empty") + language_model_module_name = MIMO_LANGUAGE_MODULE_KEY if MIMO_LANGUAGE_MODULE_KEY in module_pgs else None + return MultiModuleProcessGroupCollection( + module_pgs=module_pgs, + language_model_module_name=language_model_module_name, + ) + except (ImportError, ValueError, TypeError) as e: + logger.warning(f"MultiModuleProcessGroupCollection failed ({e}), using list-based fallback") + return [pg for pg in infra.pg_collections.values() if pg is not None] + + +@contextmanager +def multimodule_no_sync(*, module_to_grid_tuple: List[Tuple]): + """Context manager to disable gradient sync for all modules during microbatch accumulation. + + This function is designed to be used with functools.partial() to pre-bind + the module_to_grid_tuple parameter, since the schedule calls no_sync_func() + with no arguments. + + Args: + module_to_grid_tuple: List of (module, grid) tuples (keyword-only, bound via partial). + + Yields: + None - context manager for gradient sync control. 
+ """ + contexts = [] + for module, grid in module_to_grid_tuple: + if module is not None and is_current_rank_in_grid(grid): + contexts.append(module.no_sync()) + + # Enter all contexts + for ctx in contexts: + ctx.__enter__() + + try: + yield + finally: + # Exit all contexts in reverse order + for ctx in reversed(contexts): + ctx.__exit__(None, None, None) + + +def finalize_model_grads_multimodule( + model, + num_tokens=None, + pg_collection=None, + force_all_reduce=None, + *, + infra: MimoModelInfra, + module_to_grid_tuple: List[Tuple], +): + """Finalize gradients for each module using infra.pg_collections. + + IMPORTANT: Signature matches schedule's call pattern: + config.finalize_model_grads_func([model], num_tokens, pg_collection, force_all_reduce=flag) + + The `infra` and `module_to_grid_tuple` parameters are pre-bound via partial(). + We ignore the schedule-provided `pg_collection` and use per-module PGs. + + Args: + model: Model list (passed by schedule, ignored - we use module_to_grid_tuple). + num_tokens: Token count for gradient scaling. + pg_collection: Schedule-provided PG (ignored - we use per-module PGs). + force_all_reduce: Schedule-provided flag (ignored - per-module PGs control sync). + infra: MimoModelInfra with per-module pg_collections (keyword-only, bound via partial). + module_to_grid_tuple: List of (module, grid) tuples (keyword-only, bound via partial). + """ + for module, grid in module_to_grid_tuple: + if module is not None and is_current_rank_in_grid(grid): + # Get the module's pg_collection from infra + # Find the module name by matching the grid + module_pg = None + for module_name, mod_grid in infra.module_to_grid_map.items(): + if mod_grid is grid: + module_pg = infra.pg_collections.get(module_name) + break + + if module_pg is not None: + _finalize_model_grads([module], num_tokens=num_tokens, pg_collection=module_pg) + + +def zero_grad_buffer_for_multimodule(module_to_grid_tuple: List[Tuple]): + """Reset gradient buffers for all DDP-wrapped modules. + + Args: + module_to_grid_tuple: List of (module, grid) tuples. + """ + for module, grid in module_to_grid_tuple: + if module is not None and is_current_rank_in_grid(grid): + if hasattr(module, "zero_grad_buffer"): + module.zero_grad_buffer() + + +def validate_no_stub_ranks(module_to_grid_map: Dict[str, "HyperCommGrid"], world_size: int): + """Ensure every rank participates in at least one module. + + Stub ranks (ranks not participating in any module) are NOT supported. + This validation runs at setup time to fail fast with a clear error. + + Args: + module_to_grid_map: Mapping of module names to their HyperCommGrids. + world_size: Total number of ranks in the world. + + Raises: + ValueError: If any rank doesn't participate in a module. + """ + participating_ranks = set() + for module_name, grid in module_to_grid_map.items(): + # Add all ranks in this grid's range + for rank in range(grid.rank_offset, grid.rank_offset + grid.size): + participating_ranks.add(rank) + + all_ranks = set(range(world_size)) + stub_ranks = all_ranks - participating_ranks + + if stub_ranks: + raise ValueError( + f"Ranks {sorted(stub_ranks)} do not participate in any module. " + f"Stub ranks are not supported. Adjust parallelism config to use all {world_size} GPUs, " + f"or reduce world_size to {len(participating_ranks)}." + ) + + +def validate_data_loader_contract( + infra: MimoModelInfra, + global_batch_size: int, + micro_batch_size: int, + num_microbatches: int, +): + """Validate data loading constraints for multimodule training. 
+ + Checks: + - Global batch size divisible by all module DP sizes + - Micro-batch size consistent with per-module sharding + - num_microbatches * micro_batch_size == global_batch_size / DP_size (per module) + + Args: + infra: MimoModelInfra with module_to_grid_map. + global_batch_size: Total batch size across all data parallel ranks. + micro_batch_size: Batch size per microbatch. + num_microbatches: Number of microbatches per iteration. + + Raises: + ValueError: If any constraint is violated. + """ + for module_name, grid in infra.module_to_grid_map.items(): + # Get DP size from grid + dp_size = grid.get_pg_size(["dp"]) + + # Check global batch divisibility + if global_batch_size % dp_size != 0: + raise ValueError(f"Global batch size {global_batch_size} not divisible by {module_name} DP size {dp_size}") + + # Check micro-batch alignment + per_dp_batch = global_batch_size // dp_size + expected = num_microbatches * micro_batch_size + if per_dp_batch != expected: + raise ValueError( + f"Microbatch mismatch for {module_name}: " + f"{num_microbatches} * {micro_batch_size} = {expected} != {per_dp_batch} " + f"(global_batch / DP_size)" + ) diff --git a/src/megatron/bridge/training/mimo_step.py b/src/megatron/bridge/training/mimo_step.py new file mode 100644 index 0000000000..eb6598dd89 --- /dev/null +++ b/src/megatron/bridge/training/mimo_step.py @@ -0,0 +1,208 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +"""MIMO-specific forward step function for use with pipeline schedules. + +This module provides the forward step function for MIMO model training. +Key design notes (per PR 3212): +- The schedule expects dict-based outputs: {module_name: tensor} instead of single tensors +- The MimoModel's forward returns output tensors that the schedule sends via MultiModulePipelineCommunicator +- The schedule's backward_step_multimodule() handles dict-based backward pass automatically +- Only the LLM module produces a loss - encoders just produce activations +""" + +from __future__ import annotations + +import logging +from functools import partial +from typing import TYPE_CHECKING, Dict, Iterable, Optional, Tuple + +import torch +from megatron.core.models.mimo import MimoModel +from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY + +from megatron.bridge.training.mimo_parallel_utils import unwrap_mimo_model +from megatron.bridge.training.state import GlobalState + + +if TYPE_CHECKING: + pass + + +logger = logging.getLogger(__name__) + + +def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor) -> Tuple: + """Loss function for MIMO model training. + + Called at the terminal stage (LLM's last PP stage). + + Args: + loss_mask: Mask indicating which tokens contribute to the loss. + output_tensor: Model output tensor (losses per token). + + Returns: + Tuple of (total_loss, num_tokens, {'lm loss': reporting_loss}). + + Note: + Only the LLM module produces a loss. Encoders produce activations + that are consumed by the LLM, but don't have their own loss. + """ + losses = output_tensor.float() + + loss_mask = loss_mask.contiguous().view(-1).float() + + total_tokens = loss_mask.sum().clone().detach().to(torch.int) + total_loss = torch.sum(losses.view(-1) * loss_mask) + reporting_loss = torch.cat([total_loss.clone().detach().view(1), total_tokens.view(1)]) + + return (total_loss, total_tokens, {"lm loss": reporting_loss}) + + +def get_batch(data_iterator: Iterable) -> Optional[Dict[str, torch.Tensor]]: + """Get batch from data iterator. 
+ + Returns dict with: + - input_ids, labels, loss_mask, position_ids (for LLM) + - modality_inputs: {modality_name: preprocessed_tensors} (for encoders) + + Uses existing MimoDataset format from Phase 3. + + Args: + data_iterator: Iterator over the dataset. + + Returns: + Batch dictionary or None if iterator is exhausted. + """ + if data_iterator is None: + return None + + try: + batch = next(data_iterator) + except StopIteration: + return None + + # Move tensors to GPU if not already there + def _move_to_cuda(obj): + if isinstance(obj, torch.Tensor): + return obj.cuda(non_blocking=True) if not obj.is_cuda else obj + if isinstance(obj, dict): + return {k: _move_to_cuda(v) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + converted = [_move_to_cuda(v) for v in obj] + return type(obj)(converted) + return obj + + if batch is not None: + batch = _move_to_cuda(batch) + + return batch + + +def forward_step( + state: GlobalState, + data_iterator: Iterable, + model: MimoModel, +) -> Tuple[torch.Tensor, Optional[partial]]: + """Forward step for MIMO model training. + + Uses 3-arg signature with GlobalState for Bridge compatibility. + The training loop wraps this with prepare_forward_step_func() which: + - Injects GlobalState automatically if forward_step accepts it + - Provides access to state.timers, state.cfg, state.train_state + + The MimoModel handles dict-based tensor flow internally: + - Encoder modules produce activations sent via BridgeCommunicator + - LLM module receives encoder outputs and produces loss + + At terminal stage: returns (loss_tensor, loss_func) + At intermediate stages: returns (output_dict, None) - schedule handles communication + + GUARDRAIL: At last stage, assert output is scalar tensor (not dict) to catch + misconfigurations early with a clear error message. + + Args: + state: GlobalState containing timers, config, train_state. + data_iterator: Iterator over the dataset. + model: MimoModel instance. + + Returns: + Tuple of (output_tensor, loss_function or None). + """ + # Get the model's role to determine if we're at first pipeline stage + mimo_model = unwrap_mimo_model(model) + + # Determine if this rank needs data. + # - LLM ranks: first stage needs input_ids; last stage needs labels/loss_mask. + # - Modality ranks: only first stage needs raw modality inputs. + needs_data = True + if mimo_model.role is not None: + if mimo_model.role.has_language_module: + module_name = MIMO_LANGUAGE_MODULE_KEY + is_first_stage = mimo_model.role.is_first_stage(module_name) + is_last_stage = mimo_model.role.is_last_stage(module_name) + needs_data = is_first_stage or is_last_stage + elif mimo_model.role.has_modality_modules: + modality_modules = mimo_model.role.modality_module_names + needs_data = any(mimo_model.role.is_first_stage(mod) for mod in modality_modules) + + if needs_data: + data_batch = get_batch(data_iterator) + if data_batch is None: + raise RuntimeError( + "get_batch returned None at a stage that requires data. " + "This indicates a data-loading or parallelism misconfiguration." + ) + else: + # Non-data stages consume hidden states from pipeline input tensors. 
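+        # (Illustrative: a middle LLM pipeline stage is neither first nor last,
+        # so needs_data is False; its inputs arrive from the previous stage via
+        # the multimodule communicator rather than the data iterator.)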
+ data_batch = { + "input_ids": None, + "position_ids": None, + "attention_mask": None, + "labels": None, + "loss_mask": None, + "modality_inputs": None, + } + + # Extract loss_mask before forward pass + loss_mask = data_batch.get("loss_mask") + + # Run forward pass + # MimoModel.forward() returns (output_tensor, loss_mask) or just output_tensor + output = model(**data_batch) + + # Handle tuple return from model + if isinstance(output, tuple): + output_tensor, model_loss_mask = output + # Use model-provided loss_mask if available + if model_loss_mask is not None: + loss_mask = model_loss_mask + else: + output_tensor = output + + # Check if we're at the last pipeline stage for the language module + # mimo_model was already unwrapped at the start of this function + if mimo_model.role is None: + is_last_stage = True + elif mimo_model.role.has_language_module: + is_last_stage = mimo_model.role.is_last_stage(MIMO_LANGUAGE_MODULE_KEY) + else: + is_last_stage = False + + if is_last_stage: + # GUARDRAIL: Verify scalar loss at last stage + if isinstance(output_tensor, dict): + raise ValueError( + f"Last pipeline stage must return scalar loss tensor, got dict with keys: {output_tensor.keys()}. " + f"Ensure the LLM module's final stage produces a loss, not activations." + ) + + # Return output and loss function + if loss_mask is not None: + return output_tensor, partial(loss_func, loss_mask) + else: + # Create default loss mask if not provided + logger.warning("No loss_mask provided, using all-ones mask") + default_mask = torch.ones_like(output_tensor) + return output_tensor, partial(loss_func, default_mask) + + # Intermediate stage - return output for activation passing + return output_tensor, None diff --git a/src/megatron/bridge/training/pretrain_mimo.py b/src/megatron/bridge/training/pretrain_mimo.py new file mode 100644 index 0000000000..a33e199648 --- /dev/null +++ b/src/megatron/bridge/training/pretrain_mimo.py @@ -0,0 +1,388 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +"""Entry point for MIMO pretraining. + +This module provides the entry point for MIMO pretraining with heterogeneous +parallelism support. It uses a setup_mimo() helper that composes with existing +setup logic rather than duplicating pretrain.py. 
+ +Key components: +- setup_mimo(): MIMO-specific setup helper +- pretrain_mimo(): Entry point for MIMO pretraining +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import TYPE_CHECKING, Callable, Dict, Iterator, List, Optional + +import torch.distributed as dist +from megatron.core.models.mimo import MimoModel +from megatron.core.pipeline_parallel.multimodule_communicator import MultiModulePipelineCommunicator +from megatron.core.utils import get_model_config + +from megatron.bridge.training.config import ConfigContainer +from megatron.bridge.training.mimo_parallel_utils import ( + build_pg_collection_for_schedule, + get_module_to_grid_tuple, + is_current_rank_in_grid, + unwrap_mimo_model, + validate_no_stub_ranks, +) +from megatron.bridge.training.state import GlobalState +from megatron.bridge.training.train_mimo import train_mimo + + +if TYPE_CHECKING: + from megatron.core.models.mimo.optimizer import MimoOptimizer + from megatron.core.optimizer.optimizer_param_scheduler import OptimizerParamScheduler + from megatron.core.process_groups_config import MultiModuleProcessGroupCollection + + from megatron.bridge.models.mimo.mimo_provider import MimoModelInfra, MimoModelProvider + + +logger = logging.getLogger(__name__) + + +def _set_mimo_random_seeds( + cfg: ConfigContainer, + mimo_infra: "MimoModelInfra", +) -> None: + """Initialize random seeds with per-module TP/PP awareness. + + Mirrors the standard path's ``_set_random_seed()`` but derives TP/PP ranks + from the per-module HyperCommGrids instead of global MPU state. + + Must be called **after** ``build_infra()`` (grids exist) and **before** + ``provide_distributed_model()`` (weight init needs the CUDA RNG tracker). + """ + import random + + import numpy as np + import torch + from megatron.core import tensor_parallel + + seed = cfg.rng.seed + + current_rank = dist.get_rank() + + # Find which module this rank belongs to and get its TP/PP ranks. + tp_rank = 0 + pp_rank = 0 + for module_name, grid in mimo_infra.module_to_grid_map.items(): + if is_current_rank_in_grid(grid): + tp_rank = dist.get_group_rank(grid.get_pg(["tp"]), current_rank) + pp_rank = dist.get_group_rank(grid.get_pg(["pp"]), current_rank) + break + + # Different PP stages get different seeds (consistent with standard path). + seed = seed + (100 * pp_rank) + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + if torch.cuda.device_count() > 0: + tensor_parallel.model_parallel_cuda_manual_seed(seed, tp_rank=tp_rank, ep_rank=0, etp_rank=0) + + logger.info( + f"Rank {current_rank}: Initialized MIMO random seeds (base_seed={seed}, tp_rank={tp_rank}, pp_rank={pp_rank})" + ) + + +@dataclass +class MimoSetupOutput: + """Output from setup_mimo() containing all components needed for training. + + Attributes: + model: MimoModel (distributed, DDP-wrapped). + mimo_infra: MimoModelInfra (grids, topology, pg_collections). + multimodule_pg_collection: PG collection for schedule. + multimodule_communicator: MultiModulePipelineCommunicator for P2P. + module_to_grid_tuple: List of (module, grid) tuples for gradient handling. + optimizer: MimoOptimizer (None when ``build_optimizer=False``). + schedulers: Per-module LR schedulers (empty when ``build_optimizer=False``). + train_data_iterator: Training data iterator. + valid_data_iterator: Validation data iterator (optional). + global_state: GlobalState containing timers, config, train_state. 
+ """ + + model: "MimoModel" + mimo_infra: "MimoModelInfra" + multimodule_pg_collection: "MultiModuleProcessGroupCollection" + multimodule_communicator: MultiModulePipelineCommunicator + module_to_grid_tuple: List + optimizer: Optional["MimoOptimizer"] + schedulers: Dict[str, "OptimizerParamScheduler"] + train_data_iterator: Iterator + valid_data_iterator: Optional[Iterator] + global_state: GlobalState + + +def setup_mimo( + cfg: ConfigContainer, + mimo_provider: "MimoModelProvider", + build_data_iterators_fn: Optional[Callable] = None, + build_optimizer: bool = True, + global_state: Optional[GlobalState] = None, +) -> MimoSetupOutput: + """MIMO-specific setup helper. + + This function sets up all components needed for MIMO training: + - Builds distributed model via MimoModelProvider + - Builds MIMO infrastructure (grids, topology, pg_collections) + - Creates MultiModulePipelineCommunicator + - Creates MimoOptimizer and per-module LR schedulers (when ``build_optimizer=True``) + - Builds data iterators (if function provided) + - Validates configuration + + Args: + cfg: ConfigContainer with training configuration. ``cfg.optimizer`` + is used to create the optimizer when ``build_optimizer=True``. + mimo_provider: MimoModelProvider for building model and infrastructure. + build_data_iterators_fn: Optional function to build data iterators. + Should have signature: (cfg, mimo_infra) -> (train_iter, valid_iter) + build_optimizer: Whether to create optimizer and schedulers. Set to + ``False`` for inference or evaluation-only callers. + global_state: Optional GlobalState. If not provided, creates a new one. + + Returns: + MimoSetupOutput containing all components for training. + + Reuses from setup.py: + - Logging setup (via global_state) + - Timer infrastructure (via global_state) + """ + # Create GlobalState if not provided + if global_state is None: + from megatron.core.timers import Timers + + from megatron.bridge.training.state import GlobalState, TrainState + + timers = Timers( + log_level=cfg.logger.timing_log_level, + log_option=cfg.logger.timing_log_option, + ) + train_state = TrainState() + global_state = GlobalState() + global_state.cfg = cfg + global_state._timers = timers + global_state.train_state = train_state + + logger.info(f"Rank {dist.get_rank()}: Setting up MIMO training") + + # Finalize and build infrastructure + mimo_provider.finalize() + mimo_infra = mimo_provider.build_infra() + + # Validate no stub ranks + world_size = dist.get_world_size() + validate_no_stub_ranks(mimo_infra.module_to_grid_map, world_size) + + # Initialize per-module random seeds before model construction. + # MIMO bypasses initialize_megatron() (to avoid global MPU corruption), which + # also skips model_parallel_cuda_manual_seed(). Without it, GPU weight init and + # TP-region dropout crash because CudaRNGStatesTracker is empty. We look up the + # per-module TP/PP ranks from HyperCommGrids and pass them explicitly. 
+ _set_mimo_random_seeds(cfg, mimo_infra) + + logger.info(f"Rank {dist.get_rank()}: Building distributed model") + + # Build distributed model + # Use DDP config from cfg if available + from megatron.core.distributed import DistributedDataParallelConfig + + ddp_config = DistributedDataParallelConfig( + grad_reduce_in_fp32=getattr(cfg.train, "grad_reduce_in_fp32", False), + overlap_grad_reduce=getattr(cfg.train, "overlap_grad_reduce", True), + use_distributed_optimizer=getattr(cfg.train, "use_distributed_optimizer", False), + check_for_nan_in_grad=getattr(cfg.train, "check_for_nan_in_grad", False), + ) + + model_list = mimo_provider.provide_distributed_model( + ddp_config=ddp_config, + fp16=cfg.model.fp16 if hasattr(cfg.model, "fp16") else False, + bf16=cfg.model.bf16 if hasattr(cfg.model, "bf16") else True, + ) + model = model_list[0] + + logger.info(f"Rank {dist.get_rank()}: Creating multimodule communicator") + + # Create MultiModulePipelineCommunicator + # IMPORTANT: MimoModel produces SBH tensors (seq, batch, hidden), NOT BSH + # See MimoModel.align_embeddings_by_token_positions() which returns [s, b, h] + model_config = get_model_config(model) + + # Ensure pipeline_dtype is set for P2P communication (required when any module uses PP > 1) + # The model config may not have this set if individual modules don't use PP + import torch + + if model_config.pipeline_dtype is None: + if getattr(model_config, "bf16", False): + model_config.pipeline_dtype = torch.bfloat16 + elif getattr(model_config, "fp16", False): + model_config.pipeline_dtype = torch.float16 + else: + model_config.pipeline_dtype = torch.float32 + + multimodule_communicator = MultiModulePipelineCommunicator( + mimo_infra.module_to_grid_map, + mimo_infra.topology, + model_config, + dim_mapping={"s": 0, "b": 1, "h": 2}, # SBH mapping - matches MimoModel output + module_output_ndim=mimo_infra.module_output_ndim, + ) + + # Build pg_collection for schedule + multimodule_pg_collection = build_pg_collection_for_schedule(mimo_infra) + + # Build module-to-grid tuple for gradient operations + module_to_grid_tuple = get_module_to_grid_tuple(model, mimo_infra) + + # Build optimizer and per-module LR schedulers + optimizer = None + schedulers: Dict[str, "OptimizerParamScheduler"] = {} + if build_optimizer: + unwrapped_model = unwrap_mimo_model(model) + if mimo_infra.module_to_grid_map: + assert unwrapped_model.mimo_config.module_to_grid_map is not None, ( + "MimoModelConfig.module_to_grid_map must be set at model construction time. " + "Ensure MimoModelProvider.provide() passes module_to_grid_map for MIMO parallelism." 
+ ) + + logger.info(f"Rank {dist.get_rank()}: Creating MimoOptimizer") + from megatron.core.models.mimo.optimizer import get_mimo_optimizer + + opt_config = cfg.optimizer + if hasattr(opt_config, "finalize"): + opt_config.finalize() + + optimizer = get_mimo_optimizer(unwrapped_model, opt_config) + + # Auto-create per-module LR schedulers + cfg._calculate_scheduler_steps() + from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler + + for name, info in optimizer.module_infos.items(): + if info.is_active and info.optimizer is not None: + schedulers[name] = OptimizerParamScheduler( + info.optimizer, + init_lr=cfg.scheduler.lr_warmup_init, + max_lr=opt_config.lr, + min_lr=opt_config.min_lr, + lr_warmup_steps=cfg.scheduler.lr_warmup_steps, + lr_decay_steps=cfg.scheduler.lr_decay_steps, + lr_decay_style=cfg.scheduler.lr_decay_style, + start_wd=cfg.scheduler.start_weight_decay, + end_wd=cfg.scheduler.end_weight_decay, + wd_incr_steps=cfg.scheduler.wd_incr_steps, + wd_incr_style=cfg.scheduler.weight_decay_incr_style, + use_checkpoint_opt_param_scheduler=cfg.scheduler.use_checkpoint_opt_param_scheduler, + override_opt_param_scheduler=cfg.scheduler.override_opt_param_scheduler, + wsd_decay_steps=cfg.scheduler.wsd_decay_steps, + lr_wsd_decay_style=cfg.scheduler.lr_wsd_decay_style, + ) + logger.info(f"Rank {dist.get_rank()}: Auto-created schedulers for modules: {list(schedulers.keys())}") + + # Build data iterators if function provided + train_data_iterator = None + valid_data_iterator = None + if build_data_iterators_fn is not None: + logger.info(f"Rank {dist.get_rank()}: Building data iterators") + train_data_iterator, valid_data_iterator = build_data_iterators_fn(cfg, mimo_infra) + + logger.info(f"Rank {dist.get_rank()}: MIMO setup complete") + + return MimoSetupOutput( + model=model, + mimo_infra=mimo_infra, + multimodule_pg_collection=multimodule_pg_collection, + multimodule_communicator=multimodule_communicator, + module_to_grid_tuple=module_to_grid_tuple, + optimizer=optimizer, + schedulers=schedulers, + train_data_iterator=train_data_iterator, + valid_data_iterator=valid_data_iterator, + global_state=global_state, + ) + + +def pretrain_mimo( + cfg: ConfigContainer, + mimo_provider: "MimoModelProvider", + forward_step_func: Callable, + build_data_iterators_fn: Callable, + schedulers: Optional[Dict[str, "OptimizerParamScheduler"]] = None, + global_state: Optional[GlobalState] = None, +) -> None: + """Entry point for MIMO pretraining. + + Steps: + 1. Call setup_mimo() to get model, optimizer, schedulers, infra, communicators + 2. Call train_mimo() with all components + + Args: + cfg: ConfigContainer with training configuration. ``cfg.optimizer`` + (a ``BridgeOptimizerConfig``, which inherits from MCore's + ``OptimizerConfig``) is used to create the ``MimoOptimizer`` + and per-module LR schedulers. + mimo_provider: MimoModelProvider for building model and infrastructure. + forward_step_func: Forward step function for training. + build_data_iterators_fn: Function to build data iterators. + Signature: (cfg, mimo_infra) -> (train_iter, valid_iter) + schedulers: Per-module learning rate schedulers {module_name: scheduler}. + If not provided, auto-created from ``cfg.optimizer`` and ``cfg.scheduler``. + global_state: Optional GlobalState. If not provided, creates a new one. + """ + logger.info("Starting MIMO pretraining") + + # MIMO: data_parallel_size is always 1 from the training loop's perspective. 
+ # All ranks load the same global micro-batch; per-module DP sharding is handled + # by slice_batch_for_mimo() in the forward step, not by the data loader or + # training loop. Hard-coding this avoids requiring callers to set it manually + # and prevents incorrect consumed-sample / scheduler-increment calculations. + cfg.data_parallel_size = 1 + + # Initialize num-microbatches calculator if not already set. + from megatron.core import num_microbatches_calculator as nmc + + if nmc._GLOBAL_NUM_MICROBATCHES_CALCULATOR is None: + nmc.init_num_microbatches_calculator( + dist.get_rank(), + getattr(cfg.train, "rampup_batch_size", None), + cfg.train.global_batch_size, + cfg.train.micro_batch_size, + cfg.data_parallel_size, + getattr(cfg.train, "decrease_batch_size_if_needed", False), + ) + + # Setup all MIMO components (model, optimizer, schedulers, data, communicators) + setup_output = setup_mimo( + cfg=cfg, + mimo_provider=mimo_provider, + build_data_iterators_fn=build_data_iterators_fn, + build_optimizer=True, + global_state=global_state, + ) + + # Allow caller-provided schedulers to override auto-created ones + final_schedulers = schedulers if schedulers else setup_output.schedulers + + logger.info(f"Rank {dist.get_rank()}: Starting training loop") + + # Run training loop + train_mimo( + forward_step_func=forward_step_func, + model=setup_output.model, + optimizer=setup_output.optimizer, + schedulers=final_schedulers, + train_data_iterator=setup_output.train_data_iterator, + valid_data_iterator=setup_output.valid_data_iterator, + global_state=setup_output.global_state, + mimo_infra=setup_output.mimo_infra, + multimodule_communicator=setup_output.multimodule_communicator, + multimodule_pg_collection=setup_output.multimodule_pg_collection, + module_to_grid_tuple=setup_output.module_to_grid_tuple, + ) + + logger.info("MIMO pretraining completed") diff --git a/src/megatron/bridge/training/train_mimo.py b/src/megatron/bridge/training/train_mimo.py new file mode 100644 index 0000000000..be2c96a069 --- /dev/null +++ b/src/megatron/bridge/training/train_mimo.py @@ -0,0 +1,462 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +"""MIMO Training Loop for heterogeneous multi-module training. + +This module provides the dedicated training loop for MIMO models with +heterogeneous parallelism. It uses MultiModulePipelineCommunicator for +cross-module communication and supports per-module gradient handling. + +Key differences from standard train(): +- Creates MultiModulePipelineCommunicator for cross-module communication +- Creates MultiModuleProcessGroupCollection for the schedule +- Uses forward_backward_pipelining_without_interleaving with multimodule support +- Uses zero_grad_buffer_for_multimodule() for gradient clearing +- Supports per-module optimizers + +Note: Stub ranks are disallowed - validated at setup time. 
+""" + +from __future__ import annotations + +import logging +from functools import partial +from typing import TYPE_CHECKING, Callable, Dict, Iterator, List, Optional, Tuple + +import torch +import torch.distributed as dist +from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY +from megatron.core.num_microbatches_calculator import get_num_microbatches +from megatron.core.pipeline_parallel.schedules import forward_backward_pipelining_without_interleaving +from megatron.core.utils import get_model_config + +from megatron.bridge.training.checkpointing import maybe_finalize_async_save, save_checkpoint +from megatron.bridge.training.eval import evaluate_and_print_results +from megatron.bridge.training.mimo_parallel_utils import ( + build_pg_collection_for_schedule, + finalize_model_grads_multimodule, + get_module_to_grid_tuple, + multimodule_no_sync, + unwrap_mimo_model, + zero_grad_buffer_for_multimodule, +) +from megatron.bridge.training.profiling import ( + handle_profiling_step, + handle_profiling_stop, + initialize_pytorch_profiler, + should_profile_rank, +) +from megatron.bridge.training.state import GlobalState +from megatron.bridge.training.utils.train_utils import ( + prepare_forward_step_func, + training_log, +) + + +if TYPE_CHECKING: + from megatron.core.models.mimo import MimoModel + from megatron.core.models.mimo.optimizer import MimoOptimizer + from megatron.core.optimizer.optimizer_param_scheduler import OptimizerParamScheduler + from megatron.core.pipeline_parallel.multimodule_communicator import MultiModulePipelineCommunicator + from megatron.core.process_groups_config import MultiModuleProcessGroupCollection + + from megatron.bridge.models.mimo.mimo_provider import MimoModelInfra + + +logger = logging.getLogger(__name__) + + +def train_step_mimo( + forward_step_func: Callable, + data_iterator: Iterator, + model: "MimoModel", + optimizer: "MimoOptimizer", + schedulers: Dict[str, "OptimizerParamScheduler"], + global_state: GlobalState, + multimodule_communicator: "MultiModulePipelineCommunicator", + multimodule_pg_collection, + infra: "MimoModelInfra", + module_to_grid_tuple: List, + num_microbatches: int, + seq_length: int, + micro_batch_size: int, +) -> Tuple[Dict[str, torch.Tensor], Optional[float], Optional[int]]: + """Single MIMO training step. + + Args: + forward_step_func: Forward step function (wrapped with GlobalState). + data_iterator: Iterator over the dataset. + model: MimoModel instance. + optimizer: MimoOptimizer managing per-module optimizers. + schedulers: Per-module learning rate schedulers {module_name: scheduler}. + global_state: GlobalState containing timers, config, train_state. + multimodule_communicator: MultiModulePipelineCommunicator for P2P. + multimodule_pg_collection: PG collection for schedule. + infra: MimoModelInfra with grids, topology, pg_collections. + module_to_grid_tuple: List of (module, grid) tuples. + num_microbatches: Number of microbatches per iteration. + seq_length: Sequence length. + micro_batch_size: Micro batch size. + + Returns: + Tuple of (loss_dict, skipped_iter, grad_norm, num_zeros_in_grad). 
+ """ + timers = global_state.timers + + # Zero gradients for all modules + zero_grad_buffer_for_multimodule(module_to_grid_tuple) + + # Run forward-backward schedule + timers("forward-backward", log_level=1).start(barrier=False) + + losses_reduced = forward_backward_pipelining_without_interleaving( + forward_step_func=forward_step_func, + data_iterator=data_iterator, + model=[model], + num_microbatches=num_microbatches, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + forward_only=False, + p2p_communicator=multimodule_communicator, + pg_collection=multimodule_pg_collection, + ) + + timers("forward-backward").stop() + + # Optimizer step - MimoOptimizer handles all modules and computes global grad norm + timers("optimizer", log_level=1).start(barrier=False) + + update_successful, grad_norm, num_zeros_in_grad = optimizer.step() + + timers("optimizer").stop() + + # Step learning rate schedulers + if update_successful: + increment = num_microbatches * micro_batch_size * global_state.cfg.data_parallel_size + for module_name, scheduler in schedulers.items(): + if scheduler is not None: + scheduler.step(increment=increment) + skipped_iter = 0 + else: + skipped_iter = 1 + + loss_dict = {} + if losses_reduced: + is_last_stage = False + # Access role from unwrapped model (handles Float16Module wrapper) + mimo_model = unwrap_mimo_model(model) + if mimo_model.role is None: + is_last_stage = True + elif mimo_model.role.has_language_module: + is_last_stage = mimo_model.role.is_last_stage(MIMO_LANGUAGE_MODULE_KEY) + + if is_last_stage: + llm_pg = infra.pg_collections.get(MIMO_LANGUAGE_MODULE_KEY) if infra.pg_collections else None + for key in losses_reduced[0].keys(): + val = [x[key].view(-1) for x in losses_reduced] + if val[0].numel() == 2: + val = torch.vstack(val).sum(dim=0) + if llm_pg is not None and llm_pg.dp_cp is not None: + torch.distributed.all_reduce(val, group=llm_pg.dp_cp) + loss_dict[key] = val[0] / val[1] + elif val[0].numel() == 1: + loss_dict[key] = torch.cat(val).mean() + else: + raise ValueError(f"Invalid value shape: {val[0].shape} for key {key}") + + # Broadcast loss_dict to all ranks (the last rank is the logging rank for + # W&B/TensorBoard). Use broadcast_object_list from the source rank so every + # rank ends up with the same dict — no fragile P2P or GPU-side pickle needed. + last_rank = dist.get_world_size() - 1 + my_rank = dist.get_rank() + + # All ranks agree on which rank holds the loss (pick highest rank with data). + has_loss = 1 if loss_dict else 0 + source_tensor = torch.tensor([my_rank if has_loss else -1], dtype=torch.int32, device="cuda") + torch.distributed.all_reduce(source_tensor, op=torch.distributed.ReduceOp.MAX) + source_rank = int(source_tensor.item()) + + # Only broadcast if the source and logging rank differ and a valid source exists. + if source_rank >= 0 and source_rank != last_rank: + obj = [loss_dict if my_rank == source_rank else None] + torch.distributed.broadcast_object_list(obj, src=source_rank) + if my_rank == last_rank: + received = obj[0] or {} + # Tensors inside the received dict carry the source rank's CUDA device; + # move them to this rank's device so training_log arithmetic works. 
+ loss_dict = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in received.items()} + + return loss_dict, skipped_iter, grad_norm, num_zeros_in_grad + + +def train_mimo( + forward_step_func: Callable, + model: "MimoModel", + optimizer: "MimoOptimizer", + schedulers: Dict[str, "OptimizerParamScheduler"], + train_data_iterator: Iterator, + valid_data_iterator: Optional[Iterator], + global_state: GlobalState, + mimo_infra: "MimoModelInfra", + multimodule_communicator: "MultiModulePipelineCommunicator", + multimodule_pg_collection: Optional["MultiModuleProcessGroupCollection"] = None, + module_to_grid_tuple: Optional[List] = None, +) -> None: + """Main MIMO training loop. + + Key differences from standard train(): + - Uses MultiModuleProcessGroupCollection for the schedule + - Uses forward_backward_pipelining_without_interleaving with multimodule support + - Uses zero_grad_buffer_for_multimodule() for gradient clearing + - Uses MimoOptimizer for coordinated gradient clipping with global norm + + Reuses from existing Bridge training: + - GlobalState for timers, config, train_state + - training_log() for metrics reporting + - handle_profiling_step() and handle_profiling_stop() for profiler lifecycle + - save_checkpoint() with MimoOptimizer for checkpointing + - evaluate_and_print_results() for validation with multimodule support + - maybe_finalize_async_save() for async checkpoint finalization + + Args: + forward_step_func: Forward step function. + model: MimoModel instance. + optimizer: MimoOptimizer managing per-module optimizers. + schedulers: Per-module learning rate schedulers {module_name: scheduler}. + train_data_iterator: Training data iterator. + valid_data_iterator: Validation data iterator (optional). + global_state: GlobalState containing timers, config, train_state. + mimo_infra: MimoModelInfra with grids, topology, pg_collections. + multimodule_communicator: MultiModulePipelineCommunicator for P2P. + multimodule_pg_collection: Pre-built PG collection for the pipeline schedule. + If None, built from mimo_infra. + module_to_grid_tuple: Pre-built (module, grid) pairs for gradient ops. + If None, built from model and mimo_infra. + """ + timers = global_state.timers + train_state = global_state.train_state + cfg = global_state.cfg + + # Get training config + train_config = cfg.train + num_microbatches = get_num_microbatches() + seq_length = cfg.dataset.seq_length + micro_batch_size = train_config.micro_batch_size + + # Prepare forward step function with GlobalState injection + wrapped_forward_step_func = prepare_forward_step_func(forward_step_func, global_state) + + # Use pre-built objects from setup_mimo if provided, otherwise build them. + if module_to_grid_tuple is None: + module_to_grid_tuple = get_module_to_grid_tuple(model, mimo_infra) + if multimodule_pg_collection is None: + multimodule_pg_collection = build_pg_collection_for_schedule(mimo_infra) + + # Guard against list fallback - MIMO training requires MultiModuleProcessGroupCollection + if isinstance(multimodule_pg_collection, list): + raise RuntimeError( + "MultiModuleProcessGroupCollection is required for MIMO training. " + "The list-based fallback is not supported. Ensure Megatron-LM PR 3212 is available." + ) + + # Use rank-local module PG for logging reductions to avoid global MPU fallback. + # NOTE: In non-colocated MIMO each rank participates in exactly one module, so + # "first non-None" unambiguously selects that module's PG. 
For colocated MIMO + # (where a rank participates in multiple modules), this selection must be + # replaced with per-module logging or an explicit module-aware reduction strategy. + local_pg_collection = next((pg for pg in mimo_infra.pg_collections.values() if pg is not None), None) + if local_pg_collection is None: + raise RuntimeError( + "No local ProcessGroupCollection found for this rank. " + "Ensure rank participation is correctly configured in MIMO infrastructure." + ) + + # Configure gradient hooks on model config + model_config = get_model_config(model) + + # Bind custom parameters via partial(), leaving schedule-provided args unbound + model_config.no_sync_func = partial(multimodule_no_sync, module_to_grid_tuple=module_to_grid_tuple) + + model_config.finalize_model_grads_func = partial( + finalize_model_grads_multimodule, + infra=mimo_infra, + module_to_grid_tuple=module_to_grid_tuple, + ) + + # Optional: Set grad_scale_func from MimoOptimizer + if optimizer is not None and hasattr(optimizer, "scale_loss"): + model_config.grad_scale_func = optimizer.scale_loss + + # Validation: variable_seq_lengths should already be True (set by MimoModelProvider) + assert model_config.variable_seq_lengths, ( + "variable_seq_lengths must be True for MIMO training. " + "This should be set by MimoModelProvider.provide_distributed_model()." + ) + + # Initialize tracking variables + total_loss_dict = {} + history_wct = [] + report_memory_flag = True + + # Get first scheduler for checkpoint saving. + # All modules share the same LR schedule, so first scheduler state is representative. + first_scheduler = next(iter(schedulers.values()), None) if schedulers else None + + # Profiler setup (mirrors train.py behavior) + prof = None + nsys_nvtx_context = None + prof_config = cfg.profiling + if prof_config and should_profile_rank(prof_config, dist.get_rank()): + if prof_config.use_pytorch_profiler: + prof = initialize_pytorch_profiler(prof_config, cfg.logger.tensorboard_dir) + prof.start() + + logger.info(f"Rank {dist.get_rank()}: Starting MIMO training loop") + + # Main training loop + timers("interval-time", log_level=0).start(barrier=True) + + while train_state.step < train_config.train_iters: + # Handle profiling + nsys_ctx = handle_profiling_step( + prof_config, + train_state.step, + dist.get_rank(), + prof, + ) + if nsys_ctx is not None: + nsys_nvtx_context = nsys_ctx + + # Start iteration timer + timers("iteration-time", log_level=0).start(barrier=False) + + # Run single training step + loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = train_step_mimo( + forward_step_func=wrapped_forward_step_func, + data_iterator=train_data_iterator, + model=model, + optimizer=optimizer, + schedulers=schedulers, + global_state=global_state, + multimodule_communicator=multimodule_communicator, + multimodule_pg_collection=multimodule_pg_collection, + infra=mimo_infra, + module_to_grid_tuple=module_to_grid_tuple, + num_microbatches=num_microbatches, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + ) + + # Stop iteration timer + timers("iteration-time").stop(barrier=False) + iteration_time = timers("iteration-time").elapsed(reset=True, barrier=False) + history_wct.append(iteration_time) + + # Update training state + train_state.step += 1 + train_state.consumed_train_samples += micro_batch_size * num_microbatches * cfg.data_parallel_size + + # Get learning rate from first scheduler + learning_rate = None + if schedulers: + sched = next(iter(schedulers.values())) + if sched is not None: + learning_rate 
= sched.get_lr(sched.optimizer.param_groups[0]) + + # Log training metrics + if not cfg.logger.skip_train_metrics_log: + # Get loss scale from MimoOptimizer + if optimizer is not None and hasattr(optimizer, "get_loss_scale"): + loss_scale = optimizer.get_loss_scale() + if hasattr(loss_scale, "item"): + loss_scale = loss_scale.item() + else: + loss_scale = 1.0 + + report_memory_flag = training_log( + loss_dict=loss_dict, + total_loss_dict=total_loss_dict, + learning_rate=learning_rate, + decoupled_learning_rate=None, + loss_scale=loss_scale, + report_memory_flag=report_memory_flag, + skipped_iter=skipped_iter, + grad_norm=grad_norm, + params_norm=None, + num_zeros_in_grad=num_zeros_in_grad, + config=cfg, + global_state=global_state, + history_wct=history_wct, + model=[model], + pg_collection=local_pg_collection, + ) + + # Log iteration-time directly for MIMO models. + # training_log only logs this inside a hasattr(config.model, "kv_channels") + # block which MIMO models don't satisfy, so we log it here as a workaround. + if cfg.logger.log_timers_to_tensorboard and train_state.step % cfg.logger.log_interval == 0: + writer = global_state.tensorboard_logger + if writer: + writer.add_scalar("iteration-time", iteration_time, train_state.step) + wandb_writer = global_state.wandb_logger + if wandb_writer: + wandb_writer.log({"iteration-time": iteration_time}, train_state.step) + + # Evaluation at specified intervals + if ( + train_config.eval_interval is not None + and train_state.step % train_config.eval_interval == 0 + and valid_data_iterator is not None + ): + timers("evaluate", log_level=0).start(barrier=True) + evaluate_and_print_results( + state=global_state, + prefix=f"iteration {train_state.step}", + forward_step_func=forward_step_func, + data_iterator=valid_data_iterator, + model=[model], + config=cfg, + verbose=False, + write_to_tensorboard=True, + p2p_communicator=multimodule_communicator, + pg_collection=multimodule_pg_collection, + ) + timers("evaluate").stop() + + # Checkpointing at specified intervals + if cfg.checkpoint.save_interval is not None and train_state.step % cfg.checkpoint.save_interval == 0: + timers("save-checkpoint", log_level=0).start(barrier=True) + save_checkpoint( + state=global_state, + model=[model], + optimizer=optimizer, + opt_param_scheduler=first_scheduler, + num_floating_point_operations_so_far=0, # TODO: Add proper FLOPs tracking + ) + timers("save-checkpoint").stop() + + # Finalize any pending async saves (non-blocking during training) + maybe_finalize_async_save( + global_state=global_state, + ckpt_cfg=cfg.checkpoint, + blocking=False, + ) + + # Stop profiling + handle_profiling_stop( + prof_config, + train_state.step, + dist.get_rank(), + prof, + nsys_nvtx_context, + ) + + # Finalize any remaining async saves before exit + maybe_finalize_async_save( + global_state=global_state, + ckpt_cfg=cfg.checkpoint, + blocking=True, + terminate=True, + ) + + timers("interval-time").stop() + + logger.info(f"Rank {dist.get_rank()}: MIMO training completed") diff --git a/src/megatron/bridge/training/utils/train_utils.py b/src/megatron/bridge/training/utils/train_utils.py index ca57a3f272..64ba31151c 100644 --- a/src/megatron/bridge/training/utils/train_utils.py +++ b/src/megatron/bridge/training/utils/train_utils.py @@ -347,6 +347,7 @@ def training_log( global_state: GlobalState, history_wct: list, model: list[MegatronModule], + pg_collection: Optional[Any] = None, log_max_attention_logit: Optional[float] = None, ) -> bool: """Log training stats (losses, learning rate, 
timings, etc.). @@ -370,6 +371,8 @@ def training_log( global_state: The global training state. history_wct (list): list of elapsed time per each iteration. model (list[MegatronModule]): megatron model state. + pg_collection (Optional[Any]): ProcessGroupCollection to use for logging reductions. + If None, falls back to extracting from model wrappers. log_max_attention_logit (Optional[float]): Maximum attention logit if available, None otherwise. Returns: bool: The updated report_memory_flag. @@ -384,7 +387,7 @@ def training_log( energy_monitor = global_state.energy_monitor logger_config = config.logger train_config = config.train - pg_collection = get_pg_collection(model) + pg_collection = pg_collection or get_pg_collection(model) loggers_exist = writer is not None or wandb_writer is not None or mlflow_logger is not None @@ -687,24 +690,25 @@ def training_log( if comet_logger: comet_logger.log_metrics({"max-attention-logit": log_max_attention_logit}, step=iteration) - if config.model.num_moe_experts is not None: + num_moe_experts = getattr(config.model, "num_moe_experts", None) + if num_moe_experts is not None: moe_loss_scale = 1 / get_num_microbatches() track_names = [] - moe_router_load_balancing_type = config.model.moe_router_load_balancing_type + moe_router_load_balancing_type = getattr(config.model, "moe_router_load_balancing_type", "") if "aux_loss" in moe_router_load_balancing_type: track_names.append("load_balancing_loss") if "seq_aux_loss" in moe_router_load_balancing_type: track_names.append("seq_load_balancing_loss") if "global_aux_loss" in moe_router_load_balancing_type: track_names.append("global_load_balancing_loss") - if config.model.moe_z_loss_coeff is not None: + if getattr(config.model, "moe_z_loss_coeff", None) is not None: track_names.append("z_loss") - if config.model.is_hybrid_model: - layers = config.model.hybrid_layer_pattern.count("E") + if getattr(config.model, "is_hybrid_model", False): + layers = getattr(config.model, "hybrid_override_pattern", "").count("E") else: - layers = config.model.num_layers + layers = getattr(config.model, "num_layers", None) track_moe_metrics( loss_scale=moe_loss_scale, @@ -712,15 +716,15 @@ def training_log( writer=writer, wandb_writer=wandb_writer, total_loss_dict=total_loss_dict, - per_layer_logging=config.model.moe_per_layer_logging, + per_layer_logging=getattr(config.model, "moe_per_layer_logging", False), force_initialize=True, track_names=track_names, num_layers=layers, - moe_layer_freq=config.model.moe_layer_freq, - mtp_num_layers=config.model.mtp_num_layers, + moe_layer_freq=getattr(config.model, "moe_layer_freq", None), + mtp_num_layers=getattr(config.model, "mtp_num_layers", None), pg_collection=pg_collection, ) - if config.model.mtp_num_layers is not None: + if getattr(config.model, "mtp_num_layers", None) is not None: mtp_loss_scale = 1 / get_num_microbatches() MTPLossLoggingHelper.track_mtp_metrics(mtp_loss_scale, iteration, writer, wandb_writer, total_loss_dict) @@ -729,14 +733,16 @@ def training_log( elapsed_time_per_iteration = elapsed_time / total_iterations # Calculate GPU utilization - num_flops = num_floating_point_operations(config, batch_size) - per_gpu_tf = num_flops / elapsed_time_per_iteration / get_world_size_safe() / 1e12 - print_rank_0( - f"Step Time : {elapsed_time_per_iteration:.2f}s GPU utilization: {per_gpu_tf:.1f}MODEL_TFLOP/s/GPU" - ) + num_flops = None + if hasattr(config.model, "kv_channels") and hasattr(config.model, "num_attention_heads"): + num_flops = num_floating_point_operations(config, 
batch_size) + per_gpu_tf = num_flops / elapsed_time_per_iteration / get_world_size_safe() / 1e12 + print_rank_0( + f"Step Time : {elapsed_time_per_iteration:.2f}s GPU utilization: {per_gpu_tf:.1f}MODEL_TFLOP/s/GPU" + ) # throughput - if logger_config.log_throughput_to_tensorboard: + if num_flops is not None and logger_config.log_throughput_to_tensorboard: if writer: writer.add_scalar("throughput/tflops/device", per_gpu_tf, iteration) writer.add_scalar("throughput/tflops", per_gpu_tf * get_world_size_safe(), iteration) @@ -777,7 +783,7 @@ def training_log( log_string += " skipped samples: {:12d} |".format(global_state.train_state.skipped_train_samples) log_string += " elapsed time per iteration (ms): {:.1f} |".format(elapsed_time_per_iteration * 1000.0) - if logger_config.log_throughput: + if num_flops is not None and logger_config.log_throughput: log_string += f" throughput per GPU (TFLOP/s/GPU): {per_gpu_tf:.1f} |" if energy_monitor is not None: diff --git a/tests/unit_tests/data/mimo/test_collate.py b/tests/unit_tests/data/mimo/test_collate.py index 8631ca34f8..6fd71997ff 100644 --- a/tests/unit_tests/data/mimo/test_collate.py +++ b/tests/unit_tests/data/mimo/test_collate.py @@ -17,6 +17,7 @@ def make_sample( return { "input_ids": torch.randint(0, 1000, (seq_length,)), "labels": torch.randint(0, 1000, (seq_length,)), + "loss_mask": torch.ones(seq_length, dtype=torch.float32), "attention_mask": torch.ones(seq_length), "position_ids": torch.arange(seq_length), "modality_inputs": modalities, @@ -34,6 +35,7 @@ def test_basic_collation(self): assert "input_ids" in result assert "labels" in result + assert "loss_mask" in result assert "attention_mask" in result assert "position_ids" in result assert "modality_inputs" in result @@ -48,6 +50,7 @@ def test_batch_dimension(self): assert result["input_ids"].shape == (batch_size, seq_length) assert result["labels"].shape == (batch_size, seq_length) + assert result["loss_mask"].shape == (batch_size, seq_length) assert result["attention_mask"].shape == (batch_size, seq_length) assert result["position_ids"].shape == (batch_size, seq_length) diff --git a/tests/unit_tests/data/mimo/test_dataset.py b/tests/unit_tests/data/mimo/test_dataset.py index 3ede7b8bdf..f0881aae4d 100644 --- a/tests/unit_tests/data/mimo/test_dataset.py +++ b/tests/unit_tests/data/mimo/test_dataset.py @@ -1,4 +1,4 @@ -# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
"""Unit tests for MimoDataset.""" import pytest diff --git a/tests/unit_tests/data/mimo/test_dp_utils.py b/tests/unit_tests/data/mimo/test_dp_utils.py index 2da887662f..64650a8172 100644 --- a/tests/unit_tests/data/mimo/test_dp_utils.py +++ b/tests/unit_tests/data/mimo/test_dp_utils.py @@ -4,6 +4,7 @@ import torch.distributed as dist from megatron.bridge.data.mimo.dp_utils import get_mimo_dp_info +from megatron.bridge.models.mimo.mimo_config import MimoParallelismConfig, ModuleParallelismConfig class FakePG: @@ -35,78 +36,94 @@ def get_pg(self, dims): return self._pgs[tuple(dims)] +def _make_mimo_cfg() -> MimoParallelismConfig: + """Create test MIMO config for heterogeneous deployment.""" + module_parallelisms = { + "vision": ModuleParallelismConfig(tensor_model_parallel_size=1, data_parallel_size=2, rank_offset=0), + "language": ModuleParallelismConfig(tensor_model_parallel_size=1, data_parallel_size=4, rank_offset=4), + } + return MimoParallelismConfig( + module_parallelisms=module_parallelisms, + ) + + def test_get_mimo_dp_info_encoder_first_pp(monkeypatch): """Test heterogeneous mode, rank in encoder module, first PP stage.""" + mimo_cfg = _make_mimo_cfg() monkeypatch.setattr(dist, "get_rank", lambda: 0) grids = { "vision": FakeGrid(0, 4, dp_rank=0, dp_size=2, pp_rank=0, pp_size=2), - "llm": FakeGrid(4, 4, dp_rank=0, dp_size=4, pp_rank=0, pp_size=1), + "language": FakeGrid(4, 4, dp_rank=0, dp_size=4, pp_rank=0, pp_size=1), } - dp_info = get_mimo_dp_info(grids) + dp_rank, dp_size, needs_data, loader_module = get_mimo_dp_info(mimo_cfg, grids) - assert dp_info.loader_module == "vision" - assert dp_info.dp_rank == 0 - assert dp_info.dp_size == 2 - assert dp_info.needs_data is True # First PP stage + assert loader_module == "vision" + assert dp_rank == 0 + assert dp_size == 2 + assert needs_data is True # First PP stage def test_get_mimo_dp_info_encoder_non_first_pp(monkeypatch): """Test heterogeneous mode, rank in encoder module, not first PP stage.""" + mimo_cfg = _make_mimo_cfg() monkeypatch.setattr(dist, "get_rank", lambda: 1) grids = { "vision": FakeGrid(0, 4, dp_rank=0, dp_size=2, pp_rank=1, pp_size=2), - "llm": FakeGrid(4, 4, dp_rank=0, dp_size=4, pp_rank=0, pp_size=1), + "language": FakeGrid(4, 4, dp_rank=0, dp_size=4, pp_rank=0, pp_size=1), } - dp_info = get_mimo_dp_info(grids) + dp_rank, dp_size, needs_data, loader_module = get_mimo_dp_info(mimo_cfg, grids) - assert dp_info.loader_module == "vision" - assert dp_info.needs_data is False # Not first PP stage + assert loader_module == "vision" + assert needs_data is False # Not first PP stage def test_get_mimo_dp_info_llm_first_pp(monkeypatch): """Test heterogeneous mode, rank in LLM module, first PP stage.""" + mimo_cfg = _make_mimo_cfg() monkeypatch.setattr(dist, "get_rank", lambda: 4) grids = { "vision": FakeGrid(0, 4, dp_rank=0, dp_size=2, pp_rank=0, pp_size=1), - "llm": FakeGrid(4, 4, dp_rank=0, dp_size=4, pp_rank=0, pp_size=2), + "language": FakeGrid(4, 4, dp_rank=0, dp_size=4, pp_rank=0, pp_size=2), } - dp_info = get_mimo_dp_info(grids) + dp_rank, dp_size, needs_data, loader_module = get_mimo_dp_info(mimo_cfg, grids) - assert dp_info.loader_module == "llm" - assert dp_info.needs_data is True # First PP stage + assert loader_module == "language" + assert needs_data is True # First PP stage def test_get_mimo_dp_info_llm_last_pp(monkeypatch): """Test heterogeneous mode, rank in LLM module, last PP stage.""" + mimo_cfg = _make_mimo_cfg() monkeypatch.setattr(dist, "get_rank", lambda: 5) grids = { "vision": FakeGrid(0, 4, dp_rank=0, 
dp_size=2, pp_rank=0, pp_size=1), - "llm": FakeGrid(4, 4, dp_rank=1, dp_size=4, pp_rank=1, pp_size=2), + "language": FakeGrid(4, 4, dp_rank=1, dp_size=4, pp_rank=1, pp_size=2), } - dp_info = get_mimo_dp_info(grids) + dp_rank, dp_size, needs_data, loader_module = get_mimo_dp_info(mimo_cfg, grids) - assert dp_info.loader_module == "llm" - assert dp_info.needs_data is True # Last PP stage + assert loader_module == "language" + assert needs_data is True # Last PP stage def test_get_mimo_dp_info_non_participating_rank(monkeypatch): """Test heterogeneous mode, rank not in any module.""" + mimo_cfg = _make_mimo_cfg() monkeypatch.setattr(dist, "get_rank", lambda: 10) # Outside all grids grids = { "vision": FakeGrid(0, 4, dp_rank=0, dp_size=2, pp_rank=0, pp_size=1), - "llm": FakeGrid(4, 4, dp_rank=0, dp_size=4, pp_rank=0, pp_size=1), + "language": FakeGrid(4, 4, dp_rank=0, dp_size=4, pp_rank=0, pp_size=1), } - dp_info = get_mimo_dp_info(grids) + dp_rank, dp_size, needs_data, loader_module = get_mimo_dp_info(mimo_cfg, grids) - assert dp_info.needs_data is False - assert dp_info.loader_module == "llm" # Default to LLM + assert needs_data is False + assert loader_module == "language" # Default to LLM diff --git a/tests/unit_tests/data/mimo/test_hf_provider.py b/tests/unit_tests/data/mimo/test_hf_provider.py index 0a3d18dc64..c5c34b2ba0 100644 --- a/tests/unit_tests/data/mimo/test_hf_provider.py +++ b/tests/unit_tests/data/mimo/test_hf_provider.py @@ -54,8 +54,8 @@ def fake_is_safe_repo(trust_remote_code, hf_path): calls.is_safe_repo += 1 return False - def fake_load_dataset(path, name=None, split=None, trust_remote_code=None): - del path, name, trust_remote_code + def fake_load_dataset(path, name=None, split=None, trust_remote_code=None, data_files=None): + del path, name, trust_remote_code, data_files calls.load_dataset += 1 if split == "validation": raise ValueError("missing split") diff --git a/tests/unit_tests/data/mimo/test_loaders.py b/tests/unit_tests/data/mimo/test_loaders.py index 30dfd962f4..218d5cc8c2 100644 --- a/tests/unit_tests/data/mimo/test_loaders.py +++ b/tests/unit_tests/data/mimo/test_loaders.py @@ -5,16 +5,13 @@ import pytest -from megatron.bridge.data.mimo.dp_utils import MimoDpInfo from megatron.bridge.data.mimo.loaders import build_mimo_data_loaders class FakeMimoModelProvider: - def __init__(self, mimo_parallelism_config): + def __init__(self, mimo_parallelism_config, grids=None): self.mimo_parallelism_config = mimo_parallelism_config - - def get_or_build_infra(self): - return SimpleNamespace(module_to_grid_map={"llm": object()}) + self._grids = grids class FakeProvider: @@ -55,7 +52,8 @@ def test_build_mimo_data_loaders_raises_when_model_not_mimo(monkeypatch): def test_build_mimo_data_loaders_raises_when_parallelism_missing(monkeypatch): _patch_mimo_provider_class(monkeypatch) cfg = SimpleNamespace( - model=FakeMimoModelProvider(mimo_parallelism_config=None), train=SimpleNamespace(micro_batch_size=2) + model=FakeMimoModelProvider(mimo_parallelism_config=None, grids={"llm": object()}), + train=SimpleNamespace(micro_batch_size=2), ) provider = FakeProvider() @@ -65,17 +63,38 @@ def test_build_mimo_data_loaders_raises_when_parallelism_missing(monkeypatch): ) +def test_build_mimo_data_loaders_raises_when_grids_missing(monkeypatch): + _patch_mimo_provider_class(monkeypatch) + cfg = SimpleNamespace( + model=FakeMimoModelProvider(mimo_parallelism_config=object(), grids=None), + train=SimpleNamespace(micro_batch_size=2), + ) + provider = FakeProvider() + + with 
pytest.raises(ValueError, match="_grids is None"): + build_mimo_data_loaders( + cfg, train_state=None, mimo_provider=provider, train_samples=4, valid_samples=2, test_samples=2 + ) + + def test_build_mimo_data_loaders_happy_path(monkeypatch): _patch_mimo_provider_class(monkeypatch) + fake_grids = {"llm": object()} + fake_parallelism_config = object() cfg = SimpleNamespace( - model=FakeMimoModelProvider(mimo_parallelism_config=object()), + model=FakeMimoModelProvider(mimo_parallelism_config=fake_parallelism_config, grids=fake_grids), train=SimpleNamespace(micro_batch_size=3), ) provider = FakeProvider() monkeypatch.setattr( "megatron.bridge.data.mimo.loaders.get_mimo_dp_info", - lambda grids: MimoDpInfo(dp_rank=1, dp_size=4, needs_data=True, loader_module="llm"), + lambda mimo_cfg, grids: (1, 4, True, "llm"), + ) + + monkeypatch.setattr( + "megatron.bridge.data.mimo.loaders.print_rank_0", + lambda *args, **kwargs: None, ) sampler_calls = [] @@ -144,13 +163,18 @@ def _fake_dataloader( def test_build_mimo_data_loaders_skips_non_data_ranks(monkeypatch): _patch_mimo_provider_class(monkeypatch) cfg = SimpleNamespace( - model=FakeMimoModelProvider(mimo_parallelism_config=object()), + model=FakeMimoModelProvider(mimo_parallelism_config=object(), grids={"llm": object()}), train=SimpleNamespace(micro_batch_size=2), ) provider = FakeProvider() monkeypatch.setattr( "megatron.bridge.data.mimo.loaders.get_mimo_dp_info", - lambda grids: MimoDpInfo(dp_rank=0, dp_size=1, needs_data=False, loader_module="llm"), + lambda mimo_cfg, grids: (0, 1, False, "llm"), + ) + + monkeypatch.setattr( + "megatron.bridge.data.mimo.loaders.print_rank_0", + lambda *args, **kwargs: None, ) train_loader, valid_loader, test_loader = build_mimo_data_loaders( diff --git a/tests/unit_tests/models/mimo/test_llava_provider.py b/tests/unit_tests/models/mimo/test_llava_provider.py index 6ed594dce6..48fe8ce2a8 100644 --- a/tests/unit_tests/models/mimo/test_llava_provider.py +++ b/tests/unit_tests/models/mimo/test_llava_provider.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. """Unit tests for LLaVA MIMO Provider.""" from unittest.mock import Mock @@ -224,7 +224,7 @@ def test_can_set_parallelism_config(self): mock_vision_encoder = Mock mimo_config = MimoParallelismConfig( module_parallelisms={ - "llm": ModuleParallelismConfig(tensor_model_parallel_size=4), + "language": ModuleParallelismConfig(tensor_model_parallel_size=4), } ) diff --git a/tests/unit_tests/models/mimo/test_mimo_builder.py b/tests/unit_tests/models/mimo/test_mimo_builder.py index 2e072ace82..346806cc57 100644 --- a/tests/unit_tests/models/mimo/test_mimo_builder.py +++ b/tests/unit_tests/models/mimo/test_mimo_builder.py @@ -1,9 +1,9 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. 
"""Unit tests for MIMO builder utilities.""" from unittest.mock import MagicMock, patch -from megatron.bridge.models.mimo.mimo_builder import _default_topology, build_hypercomm_grids +from megatron.bridge.models.mimo.mimo_builder import build_hypercomm_grids from megatron.bridge.models.mimo.mimo_config import MimoParallelismConfig, ModuleParallelismConfig @@ -15,7 +15,7 @@ def test_build_with_single_module(self, mock_grid_class): """Test build_hypercomm_grids with single LLM module.""" mimo_config = MimoParallelismConfig( module_parallelisms={ - "llm": ModuleParallelismConfig( + "language": ModuleParallelismConfig( tensor_model_parallel_size=2, context_parallel_size=1, expert_tensor_parallel_size=1, @@ -32,8 +32,8 @@ def test_build_with_single_module(self, mock_grid_class): grids = build_hypercomm_grids(mimo_config) # Should create one grid - assert "llm" in grids - assert grids["llm"] == mock_grid + assert "language" in grids + assert grids["language"] == mock_grid # Check grid was created with correct shape mock_grid_class.assert_called_once() @@ -57,7 +57,7 @@ def test_build_with_multiple_modules(self, mock_grid_class): """Test build_hypercomm_grids with multiple modules.""" mimo_config = MimoParallelismConfig( module_parallelisms={ - "llm": ModuleParallelismConfig( + "language": ModuleParallelismConfig( tensor_model_parallel_size=4, data_parallel_size=2, rank_offset=0, @@ -82,7 +82,7 @@ def test_build_with_multiple_modules(self, mock_grid_class): grids = build_hypercomm_grids(mimo_config) # Should create three grids - assert "llm" in grids + assert "language" in grids assert "clip_encoder" in grids assert "dino_encoder" in grids assert len(grids) == 3 @@ -95,7 +95,7 @@ def test_build_with_different_parallelism_per_module(self, mock_grid_class): """Test grids with different parallelism configs per module.""" mimo_config = MimoParallelismConfig( module_parallelisms={ - "llm": ModuleParallelismConfig( + "language": ModuleParallelismConfig( tensor_model_parallel_size=8, pipeline_model_parallel_size=2, data_parallel_size=1, @@ -134,7 +134,7 @@ def test_build_creates_all_dimension_groups(self, mock_grid_class): """Test that all dimension process groups are created.""" mimo_config = MimoParallelismConfig( module_parallelisms={ - "llm": ModuleParallelismConfig( + "language": ModuleParallelismConfig( tensor_model_parallel_size=2, context_parallel_size=2, expert_tensor_parallel_size=2, @@ -170,7 +170,7 @@ def test_build_uses_nccl_backend(self, mock_grid_class): """Test that grids use nccl backend.""" mimo_config = MimoParallelismConfig( module_parallelisms={ - "llm": ModuleParallelismConfig(tensor_model_parallel_size=2, data_parallel_size=2), + "language": ModuleParallelismConfig(tensor_model_parallel_size=2, data_parallel_size=2), } ) @@ -189,7 +189,7 @@ def test_build_with_rank_offsets(self, mock_grid_class): """Test that rank_offset is correctly passed to grids.""" mimo_config = MimoParallelismConfig( module_parallelisms={ - "llm": ModuleParallelismConfig( + "language": ModuleParallelismConfig( tensor_model_parallel_size=2, data_parallel_size=2, rank_offset=0, @@ -214,93 +214,3 @@ def test_build_with_rank_offsets(self, mock_grid_class): encoder_kwargs = mock_grid_class.call_args_list[1][1] assert encoder_kwargs["rank_offset"] == 4 - - -class TestDefaultTopology: - """Test cases for _default_topology().""" - - def test_topology_with_single_encoder(self): - """Test topology with LLM and one encoder.""" - mimo_config = MimoParallelismConfig( - module_parallelisms={ - "llm": 
ModuleParallelismConfig(tensor_model_parallel_size=2), - "clip_encoder": ModuleParallelismConfig(tensor_model_parallel_size=2), - } - ) - - topology = _default_topology(mimo_config) - - # Encoder should point to LLM - assert topology["clip_encoder"] == ["llm"] - # LLM should have no downstream - assert topology["llm"] == [] - - def test_topology_with_multiple_encoders(self): - """Test topology with LLM and multiple encoders.""" - mimo_config = MimoParallelismConfig( - module_parallelisms={ - "llm": ModuleParallelismConfig(tensor_model_parallel_size=2), - "clip_encoder": ModuleParallelismConfig(tensor_model_parallel_size=2), - "dino_encoder": ModuleParallelismConfig(tensor_model_parallel_size=2), - "audio_encoder": ModuleParallelismConfig(tensor_model_parallel_size=2), - } - ) - - topology = _default_topology(mimo_config) - - # All encoders should point to LLM - assert topology["clip_encoder"] == ["llm"] - assert topology["dino_encoder"] == ["llm"] - assert topology["audio_encoder"] == ["llm"] - # LLM should have no downstream - assert topology["llm"] == [] - - def test_topology_with_llm_only(self): - """Test topology with only LLM module.""" - mimo_config = MimoParallelismConfig( - module_parallelisms={ - "llm": ModuleParallelismConfig(tensor_model_parallel_size=2), - } - ) - - topology = _default_topology(mimo_config) - - # LLM should have no downstream - assert topology["llm"] == [] - # Should only have one entry - assert len(topology) == 1 - - def test_topology_structure(self): - """Test that topology has correct structure (dict of lists).""" - mimo_config = MimoParallelismConfig( - module_parallelisms={ - "llm": ModuleParallelismConfig(tensor_model_parallel_size=2), - "encoder": ModuleParallelismConfig(tensor_model_parallel_size=2), - } - ) - - topology = _default_topology(mimo_config) - - # Check it's a dict - assert isinstance(topology, dict) - # Check values are lists - for value in topology.values(): - assert isinstance(value, list) - - def test_topology_all_modules_present(self): - """Test that all modules appear in topology.""" - mimo_config = MimoParallelismConfig( - module_parallelisms={ - "llm": ModuleParallelismConfig(tensor_model_parallel_size=2), - "encoder1": ModuleParallelismConfig(tensor_model_parallel_size=2), - "encoder2": ModuleParallelismConfig(tensor_model_parallel_size=2), - } - ) - - topology = _default_topology(mimo_config) - - # All modules should be present in topology - assert "llm" in topology - assert "encoder1" in topology - assert "encoder2" in topology - assert len(topology) == 3 diff --git a/tests/unit_tests/models/mimo/test_mimo_ddp.py b/tests/unit_tests/models/mimo/test_mimo_ddp.py index 64af7b8493..6686e65952 100644 --- a/tests/unit_tests/models/mimo/test_mimo_ddp.py +++ b/tests/unit_tests/models/mimo/test_mimo_ddp.py @@ -3,70 +3,10 @@ from unittest.mock import MagicMock, patch -from megatron.bridge.models.mimo.mimo_builder import is_current_rank_in_grid from megatron.bridge.models.mimo.mimo_config import MimoParallelismConfig, ModuleParallelismConfig from megatron.bridge.models.mimo.mimo_ddp import wrap_mimo_model_distributed -class TestIsCurrentRankInGrid: - """Test cases for is_current_rank_in_grid helper.""" - - @patch("torch.distributed.get_rank") - def test_rank_in_grid(self, mock_get_rank): - """Rank within grid range should return True.""" - mock_get_rank.return_value = 2 - - mock_grid = MagicMock() - mock_grid.rank_offset = 0 - mock_grid.size = 4 - - assert is_current_rank_in_grid(mock_grid) is True - - @patch("torch.distributed.get_rank") - 
def test_rank_at_grid_start(self, mock_get_rank): - """Rank at grid start should return True.""" - mock_get_rank.return_value = 4 - - mock_grid = MagicMock() - mock_grid.rank_offset = 4 - mock_grid.size = 4 - - assert is_current_rank_in_grid(mock_grid) is True - - @patch("torch.distributed.get_rank") - def test_rank_at_grid_end_exclusive(self, mock_get_rank): - """Rank at grid end (exclusive) should return False.""" - mock_get_rank.return_value = 8 - - mock_grid = MagicMock() - mock_grid.rank_offset = 4 - mock_grid.size = 4 - - assert is_current_rank_in_grid(mock_grid) is False - - @patch("torch.distributed.get_rank") - def test_rank_before_grid(self, mock_get_rank): - """Rank before grid range should return False.""" - mock_get_rank.return_value = 2 - - mock_grid = MagicMock() - mock_grid.rank_offset = 4 - mock_grid.size = 4 - - assert is_current_rank_in_grid(mock_grid) is False - - @patch("torch.distributed.get_rank") - def test_rank_after_grid(self, mock_get_rank): - """Rank after grid range should return False.""" - mock_get_rank.return_value = 10 - - mock_grid = MagicMock() - mock_grid.rank_offset = 0 - mock_grid.size = 4 - - assert is_current_rank_in_grid(mock_grid) is False - - class TestWrapMimoModelDistributed: """Test cases for wrap_mimo_model_distributed.""" @@ -124,12 +64,12 @@ def test_wrap_language_model(self, mock_get_rank, mock_ddp): ddp_config = MagicMock() mimo_parallelism_config = self._create_mimo_parallelism_config( { - "llm": {"tp": 2, "dp": 2}, + "language": {"tp": 2, "dp": 2}, } ) - grids = {"llm": self._create_mock_grid(rank_offset=0, size=4)} - pg_collections = {"llm": MagicMock()} + grids = {"language": self._create_mock_grid(rank_offset=0, size=4)} + pg_collections = {"language": MagicMock()} result = wrap_mimo_model_distributed(mimo_model, ddp_config, mimo_parallelism_config, grids, pg_collections) @@ -149,12 +89,12 @@ def test_skip_language_model_non_participating_rank(self, mock_get_rank, mock_dd ddp_config = MagicMock() mimo_parallelism_config = self._create_mimo_parallelism_config( { - "llm": {"tp": 2, "dp": 2}, + "language": {"tp": 2, "dp": 2}, } ) - grids = {"llm": self._create_mock_grid(rank_offset=0, size=4)} - pg_collections = {"llm": MagicMock()} + grids = {"language": self._create_mock_grid(rank_offset=0, size=4)} + pg_collections = {"language": MagicMock()} result = wrap_mimo_model_distributed(mimo_model, ddp_config, mimo_parallelism_config, grids, pg_collections) @@ -173,17 +113,17 @@ def test_wrap_modality_submodules(self, mock_get_rank, mock_ddp): ddp_config = MagicMock() mimo_parallelism_config = self._create_mimo_parallelism_config( { - "llm": {"tp": 2, "dp": 2}, + "language": {"tp": 2, "dp": 2}, "images": {"tp": 1, "dp": 4}, } ) grids = { - "llm": self._create_mock_grid(rank_offset=0, size=4), + "language": self._create_mock_grid(rank_offset=0, size=4), "images": self._create_mock_grid(rank_offset=0, size=4), } pg_collections = { - "llm": MagicMock(), + "language": MagicMock(), "images": MagicMock(), } @@ -192,6 +132,38 @@ def test_wrap_modality_submodules(self, mock_get_rank, mock_ddp): # Should wrap both language model and images submodule assert mock_ddp.call_count == 2 + @patch("megatron.core.distributed.DistributedDataParallel") + @patch("torch.distributed.get_rank") + def test_skip_modality_submodule_no_grid(self, mock_get_rank, mock_ddp): + """Test that modality submodules without grids are skipped.""" + mock_get_rank.return_value = 0 + mock_ddp.return_value = MagicMock() + + mimo_model = 
self._create_mock_mimo_model(has_language_model=True, modality_names=["images", "audio"])
+        ddp_config = MagicMock()
+        mimo_parallelism_config = self._create_mimo_parallelism_config(
+            {
+                "language": {"tp": 2, "dp": 2},
+                "images": {"tp": 1, "dp": 4},
+                # Note: no "audio" in parallelism config
+            }
+        )
+
+        # Only language and images have grids
+        grids = {
+            "language": self._create_mock_grid(rank_offset=0, size=4),
+            "images": self._create_mock_grid(rank_offset=0, size=4),
+        }
+        pg_collections = {
+            "language": MagicMock(),
+            "images": MagicMock(),
+        }
+
+        wrap_mimo_model_distributed(mimo_model, ddp_config, mimo_parallelism_config, grids, pg_collections)
+
+        # Should wrap language and images, but not audio (no grid)
+        assert mock_ddp.call_count == 2
+
     @patch("megatron.core.distributed.DistributedDataParallel")
     @patch("torch.distributed.get_rank")
     def test_heterogeneous_different_rank_ranges(self, mock_get_rank, mock_ddp):
@@ -205,17 +177,17 @@
         ddp_config = MagicMock()
         mimo_parallelism_config = self._create_mimo_parallelism_config(
             {
-                "llm": {"tp": 2, "dp": 2, "rank_offset": 0},
+                "language": {"tp": 2, "dp": 2, "rank_offset": 0},
                 "images": {"tp": 2, "dp": 2, "rank_offset": 4},
             }
         )
 
         grids = {
-            "llm": self._create_mock_grid(rank_offset=0, size=4),
+            "language": self._create_mock_grid(rank_offset=0, size=4),
             "images": self._create_mock_grid(rank_offset=4, size=4),
         }
         pg_collections = {
-            "llm": None,  # Rank 4 doesn't participate in LLM
+            "language": None,  # Rank 4 doesn't participate in LLM
             "images": MagicMock(),
         }
 
@@ -237,17 +209,17 @@ def test_no_language_model(self, mock_get_rank, mock_ddp):
         ddp_config = MagicMock()
         mimo_parallelism_config = self._create_mimo_parallelism_config(
             {
-                "llm": {"tp": 2, "dp": 2},
+                "language": {"tp": 2, "dp": 2},
                 "images": {"tp": 1, "dp": 4},
             }
         )
 
         grids = {
-            "llm": self._create_mock_grid(rank_offset=0, size=4),
+            "language": self._create_mock_grid(rank_offset=0, size=4),
             "images": self._create_mock_grid(rank_offset=0, size=4),
         }
         pg_collections = {
-            "llm": MagicMock(),
+            "language": MagicMock(),
             "images": MagicMock(),
         }
 
@@ -268,12 +240,12 @@ def test_returns_same_model_instance(self, mock_get_rank, mock_ddp):
         ddp_config = MagicMock()
         mimo_parallelism_config = self._create_mimo_parallelism_config(
             {
-                "llm": {"tp": 2, "dp": 2},
+                "language": {"tp": 2, "dp": 2},
             }
         )
 
-        grids = {"llm": self._create_mock_grid(rank_offset=0, size=4)}
-        pg_collections = {"llm": MagicMock()}
+        grids = {"language": self._create_mock_grid(rank_offset=0, size=4)}
+        pg_collections = {"language": MagicMock()}
 
         result = wrap_mimo_model_distributed(mimo_model, ddp_config, mimo_parallelism_config, grids, pg_collections)
 
@@ -295,13 +267,13 @@ def test_ddp_called_with_correct_args(self, mock_get_rank, mock_ddp):
         ddp_config = MagicMock()
         mimo_parallelism_config = self._create_mimo_parallelism_config(
             {
-                "llm": {"tp": 2, "dp": 2},
+                "language": {"tp": 2, "dp": 2},
             }
         )
 
-        grids = {"llm": self._create_mock_grid(rank_offset=0, size=4)}
+        grids = {"language": self._create_mock_grid(rank_offset=0, size=4)}
         llm_pg_collection = MagicMock()
-        pg_collections = {"llm": llm_pg_collection}
+        pg_collections = {"language": llm_pg_collection}
 
         wrap_mimo_model_distributed(mimo_model, ddp_config, mimo_parallelism_config, grids, pg_collections)
diff --git a/tests/unit_tests/models/mimo/test_mimo_provider.py b/tests/unit_tests/models/mimo/test_mimo_provider.py
index aa27977dc9..6b7bc1c733 100644
--- a/tests/unit_tests/models/mimo/test_mimo_provider.py
+++ b/tests/unit_tests/models/mimo/test_mimo_provider.py @@ -3,7 +3,6 @@ from unittest.mock import MagicMock, Mock, patch -import pytest from megatron.core.transformer.spec_utils import ModuleSpec from megatron.bridge.models.mimo import ( @@ -35,7 +34,7 @@ def test_provider_initialization_full(self): modality_spec = ModuleSpec(module=Mock, params={}) mimo_parallelism_config = MimoParallelismConfig( module_parallelisms={ - "llm": ModuleParallelismConfig(tensor_model_parallel_size=2), + "language": ModuleParallelismConfig(tensor_model_parallel_size=2), }, ) @@ -65,7 +64,6 @@ def test_provider_has_mixin_fields(self): assert hasattr(provider, "bf16") assert hasattr(provider, "use_cpu_initialization") assert hasattr(provider, "init_model_with_meta_device") - assert hasattr(provider, "virtual_pipeline_model_parallel_size") # Check defaults assert provider.fp16 is False @@ -93,6 +91,8 @@ def test_provide_returns_model_directly(self, mock_build_grids, mock_mimo_model) # Should not build grids when no parallelism config mock_build_grids.assert_not_called() + config_arg = mock_mimo_model.call_args[0][0] + assert config_arg.module_to_grid_map is None @patch("megatron.bridge.models.mimo.mimo_provider.MimoModel") @patch("megatron.bridge.models.mimo.mimo_provider.build_hypercomm_grids") @@ -117,10 +117,10 @@ def test_build_infra_without_parallelism(self, mock_build_grids): infra = provider.build_infra() - # Should return empty infrastructure + # Should return infrastructure with auto-derived topology assert isinstance(infra, MimoModelInfra) assert infra.module_to_grid_map == {} - assert infra.topology == {} + assert infra.topology == {"language": []} assert infra.pg_collections == {} assert infra.participating_modules == [] @@ -131,10 +131,7 @@ def test_build_infra_without_parallelism(self, mock_build_grids): @patch("torch.distributed.get_process_group_ranks") @patch("torch.distributed.get_rank") @patch("megatron.bridge.models.mimo.mimo_provider.build_hypercomm_grids") - @patch("megatron.bridge.models.mimo.mimo_provider._default_topology") - def test_build_infra_with_parallelism( - self, mock_topology, mock_build_grids, mock_get_rank, mock_get_pg_ranks, mock_new_group - ): + def test_build_infra_with_parallelism(self, mock_build_grids, mock_get_rank, mock_get_pg_ranks, mock_new_group): """Test build_infra() with parallelism config.""" mock_get_rank.return_value = 0 mock_get_pg_ranks.return_value = [0, 1, 2, 3] @@ -143,7 +140,7 @@ def test_build_infra_with_parallelism( mimo_parallelism_config = MimoParallelismConfig( module_parallelisms={ - "llm": ModuleParallelismConfig( + "language": ModuleParallelismConfig( tensor_model_parallel_size=2, data_parallel_size=2, ), @@ -155,8 +152,7 @@ def test_build_infra_with_parallelism( mock_grid.rank_offset = 0 mock_grid.size = 4 mock_grid.get_pg.return_value = MagicMock() - mock_build_grids.return_value = {"llm": mock_grid} - mock_topology.return_value = {"llm": []} + mock_build_grids.return_value = {"language": mock_grid} provider = MimoModelProvider( language_model_spec=language_spec, @@ -170,18 +166,15 @@ def test_build_infra_with_parallelism( # Should return populated infrastructure assert isinstance(infra, MimoModelInfra) - assert "llm" in infra.module_to_grid_map - assert "llm" in infra.pg_collections - assert "llm" in infra.participating_modules + assert "language" in infra.module_to_grid_map + assert "language" in infra.pg_collections + assert "language" in infra.participating_modules @patch("torch.distributed.new_group") 
@patch("torch.distributed.get_process_group_ranks") @patch("torch.distributed.get_rank") @patch("megatron.bridge.models.mimo.mimo_provider.build_hypercomm_grids") - @patch("megatron.bridge.models.mimo.mimo_provider._default_topology") - def test_build_infra_is_idempotent( - self, mock_topology, mock_build_grids, mock_get_rank, mock_get_pg_ranks, mock_new_group - ): + def test_build_infra_is_idempotent(self, mock_build_grids, mock_get_rank, mock_get_pg_ranks, mock_new_group): """Test build_infra() can be called multiple times.""" mock_get_rank.return_value = 0 mock_get_pg_ranks.return_value = [0, 1] @@ -190,7 +183,7 @@ def test_build_infra_is_idempotent( mimo_parallelism_config = MimoParallelismConfig( module_parallelisms={ - "llm": ModuleParallelismConfig(tensor_model_parallel_size=2, data_parallel_size=1, rank_offset=0), + "language": ModuleParallelismConfig(tensor_model_parallel_size=2, data_parallel_size=1, rank_offset=0), }, ) @@ -198,8 +191,7 @@ def test_build_infra_is_idempotent( mock_grid.rank_offset = 0 mock_grid.size = 2 mock_grid.get_pg.return_value = MagicMock() - mock_build_grids.return_value = {"llm": mock_grid} - mock_topology.return_value = {"llm": []} + mock_build_grids.return_value = {"language": mock_grid} provider = MimoModelProvider( language_model_spec=language_spec, @@ -213,52 +205,13 @@ def test_build_infra_is_idempotent( # Should return equivalent results (not cached, but same structure) assert infra1.participating_modules == infra2.participating_modules - @patch("torch.distributed.new_group") - @patch("torch.distributed.get_process_group_ranks") - @patch("torch.distributed.get_rank") - @patch("megatron.bridge.models.mimo.mimo_provider.build_hypercomm_grids") - @patch("megatron.bridge.models.mimo.mimo_provider._default_topology") - def test_get_or_build_infra_caches_result( - self, mock_topology, mock_build_grids, mock_get_rank, mock_get_pg_ranks, mock_new_group - ): - """Test get_or_build_infra() builds once and reuses cached infra.""" - mock_get_rank.return_value = 0 - mock_get_pg_ranks.return_value = [0, 1] - mock_new_group.return_value = MagicMock() - language_spec = ModuleSpec(module=Mock, params={"config": Mock()}) - - mimo_parallelism_config = MimoParallelismConfig( - module_parallelisms={ - "llm": ModuleParallelismConfig(tensor_model_parallel_size=2, data_parallel_size=1, rank_offset=0), - }, - ) - - mock_grid = MagicMock() - mock_grid.rank_offset = 0 - mock_grid.size = 2 - mock_grid.get_pg.return_value = MagicMock() - mock_build_grids.return_value = {"llm": mock_grid} - mock_topology.return_value = {"llm": []} - - provider = MimoModelProvider( - language_model_spec=language_spec, - mimo_parallelism_config=mimo_parallelism_config, - ) - - infra1 = provider.get_or_build_infra() - infra2 = provider.get_or_build_infra() - - assert infra1 is infra2 - mock_build_grids.assert_called_once_with(mimo_parallelism_config) - @patch("torch.distributed.new_group") @patch("torch.distributed.get_process_group_ranks") @patch("torch.distributed.get_rank") @patch("megatron.bridge.models.mimo.mimo_provider.MimoModel") @patch("megatron.bridge.models.mimo.mimo_provider.build_hypercomm_grids") - @patch("megatron.bridge.models.mimo.mimo_provider._default_topology") def test_provide_with_parallelism( - self, mock_topology, mock_build_grids, mock_mimo_model, mock_get_rank, mock_get_pg_ranks, mock_new_group + self, mock_build_grids, mock_mimo_model, mock_get_rank, mock_get_pg_ranks, mock_new_group ): """Test provide() with parallelism config.""" mock_get_rank.return_value = 0 @@ 
-268,7 +221,7 @@ def test_provide_with_parallelism( mimo_parallelism_config = MimoParallelismConfig( module_parallelisms={ - "llm": ModuleParallelismConfig( + "language": ModuleParallelismConfig( tensor_model_parallel_size=2, data_parallel_size=2, ), @@ -279,8 +232,7 @@ def test_provide_with_parallelism( mock_grid.rank_offset = 0 mock_grid.size = 4 mock_grid.get_pg.return_value = MagicMock() - mock_build_grids.return_value = {"llm": mock_grid} - mock_topology.return_value = {"llm": []} + mock_build_grids.return_value = {"language": mock_grid} provider = MimoModelProvider( language_model_spec=language_spec, @@ -294,11 +246,13 @@ def test_provide_with_parallelism( # Should return model directly assert model == mock_model_instance + config_arg = mock_mimo_model.call_args[0][0] + assert config_arg.module_to_grid_map == {"language": mock_grid} # Infrastructure should be available via build_infra() infra = provider.build_infra() - assert "llm" in infra.module_to_grid_map - assert "llm" in infra.pg_collections + assert "language" in infra.module_to_grid_map + assert "language" in infra.pg_collections def test_inject_pg_collection_into_language_spec(self): """Test that pg_collection is injected into language specs.""" @@ -359,9 +313,8 @@ def test_freezing_language_model(self, mock_mimo_model): @patch("torch.distributed.get_rank") @patch("megatron.bridge.models.mimo.mimo_provider.MimoModel") @patch("megatron.bridge.models.mimo.mimo_provider.build_hypercomm_grids") - @patch("megatron.bridge.models.mimo.mimo_provider._default_topology") def test_per_encoder_parallelism( - self, mock_topology, mock_build_grids, mock_mimo_model, mock_get_rank, mock_get_pg_ranks, mock_new_group + self, mock_build_grids, mock_mimo_model, mock_get_rank, mock_get_pg_ranks, mock_new_group ): """Test per-encoder parallelism with different TP per encoder.""" mock_get_rank.return_value = 0 @@ -373,7 +326,7 @@ def test_per_encoder_parallelism( mimo_parallelism_config = MimoParallelismConfig( module_parallelisms={ - "llm": ModuleParallelismConfig(tensor_model_parallel_size=8, data_parallel_size=1), + "language": ModuleParallelismConfig(tensor_model_parallel_size=8, data_parallel_size=1), "clip_encoder": ModuleParallelismConfig(tensor_model_parallel_size=2, data_parallel_size=1), "dino_encoder": ModuleParallelismConfig(tensor_model_parallel_size=4, data_parallel_size=1), }, @@ -396,17 +349,11 @@ def test_per_encoder_parallelism( dino_grid.get_pg.return_value = MagicMock() mock_build_grids.return_value = { - "llm": llm_grid, + "language": llm_grid, "clip_encoder": clip_grid, "dino_encoder": dino_grid, } - mock_topology.return_value = { - "clip_encoder": ["llm"], - "dino_encoder": ["llm"], - "llm": [], - } - provider = MimoModelProvider( language_model_spec=language_spec, modality_submodules_spec={ @@ -430,21 +377,24 @@ def test_per_encoder_parallelism( mock_build_grids.assert_called_with(mimo_parallelism_config) # Should have pg_collections for all modules - assert "llm" in infra.pg_collections + assert "language" in infra.pg_collections assert "clip_encoder" in infra.pg_collections assert "dino_encoder" in infra.pg_collections # Should return model directly assert model == mock_model_instance - def test_initialize_model_parallel_is_noop(self): - """Test that initialize_model_parallel() is a no-op for MIMO.""" + def test_initialize_model_parallel_raises(self): + """Test that initialize_model_parallel() raises NotImplementedError for MIMO.""" language_spec = ModuleSpec(module=Mock, params={"config": Mock()}) provider = 
MimoModelProvider(language_model_spec=language_spec) - # Should not raise, should be a no-op - provider.initialize_model_parallel(seed=42) - provider.initialize_model_parallel() + import pytest + + with pytest.raises(NotImplementedError, match="MIMO does not use global model parallelism"): + provider.initialize_model_parallel(seed=42) + with pytest.raises(NotImplementedError, match="MIMO does not use global model parallelism"): + provider.initialize_model_parallel() @patch("megatron.core.transformer.module.Float16Module") @patch("megatron.bridge.models.mimo.mimo_provider.get_model_config") @@ -478,32 +428,16 @@ def test_provide_distributed_model_sets_variable_seq_lengths( # Should have set variable_seq_lengths=True assert mock_config.variable_seq_lengths is True - def test_provide_distributed_model_propagates_finalize_error(self): - """Test provider surfaces finalize() errors from MimoParallelismConfig.""" - language_spec = ModuleSpec(module=Mock, params={"config": Mock()}) - mock_parallelism_config = Mock() - mock_parallelism_config.finalize.side_effect = ValueError("invalid mimo config") - - provider = MimoModelProvider( - language_model_spec=language_spec, - mimo_parallelism_config=mock_parallelism_config, - ) - - with pytest.raises(ValueError, match="invalid mimo config"): - provider.provide_distributed_model(wrap_with_ddp=False) - - mock_parallelism_config.finalize.assert_called_once() - class TestMimoModelInfra: """Test cases for MimoModelInfra dataclass.""" def test_infra_initialization(self): """Test infrastructure dataclass initializes correctly.""" - grids = {"llm": MagicMock()} - topology = {"llm": []} - pg_collections = {"llm": MagicMock()} - participating = ["llm"] + grids = {"language": MagicMock()} + topology = {"language": []} + pg_collections = {"language": MagicMock()} + participating = ["language"] infra = MimoModelInfra( module_to_grid_map=grids, @@ -526,14 +460,14 @@ class TestEmbeddingGroupHelpers: def test_populate_embedding_groups_single_pp_rank(self, mock_get_ranks, mock_new_group): """Test embedding groups with single PP rank (PP=1).""" from megatron.bridge.models.mimo.mimo_builder import ( - create_embedding_and_position_groups, + populate_embedding_and_position_groups, ) mock_pp_group = MagicMock() mock_get_ranks.return_value = [0] # Single PP rank mock_new_group.return_value = MagicMock() - create_embedding_and_position_groups(mock_pp_group) + populate_embedding_and_position_groups(mock_pp_group) # Should create groups for both position and word embeddings assert mock_new_group.call_count == 2 @@ -547,14 +481,14 @@ def test_populate_embedding_groups_single_pp_rank(self, mock_get_ranks, mock_new def test_populate_embedding_groups_multiple_pp_ranks(self, mock_get_ranks, mock_new_group): """Test embedding groups with multiple PP ranks (PP>1).""" from megatron.bridge.models.mimo.mimo_builder import ( - create_embedding_and_position_groups, + populate_embedding_and_position_groups, ) mock_pp_group = MagicMock() mock_get_ranks.return_value = [0, 4, 8, 12] # PP=4 mock_new_group.return_value = MagicMock() - create_embedding_and_position_groups(mock_pp_group) + populate_embedding_and_position_groups(mock_pp_group) # Should create two groups assert mock_new_group.call_count == 2 @@ -567,21 +501,81 @@ def test_populate_embedding_groups_multiple_pp_ranks(self, mock_get_ranks, mock_ def test_populate_embedding_groups_none_pp_group(self): """Test embedding groups with None PP group.""" from megatron.bridge.models.mimo.mimo_builder import ( - 
create_embedding_and_position_groups, + populate_embedding_and_position_groups, ) - pos_embd_pg, embd_pg = create_embedding_and_position_groups(None) + pos_embd_pg, embd_pg = populate_embedding_and_position_groups(None) assert pos_embd_pg is None assert embd_pg is None + @patch("torch.distributed.get_process_group_ranks") + @patch("torch.distributed.get_rank") + def test_is_pp_first_stage_true(self, mock_get_rank, mock_get_ranks): + """Test is_pp_first_stage returns True for first stage.""" + from megatron.bridge.models.mimo.mimo_builder import is_pp_first_stage + + mock_pp_group = MagicMock() + mock_get_ranks.return_value = [0, 4, 8, 12] + mock_get_rank.return_value = 0 + + assert is_pp_first_stage(mock_pp_group) is True + + @patch("torch.distributed.get_process_group_ranks") + @patch("torch.distributed.get_rank") + def test_is_pp_first_stage_false(self, mock_get_rank, mock_get_ranks): + """Test is_pp_first_stage returns False for non-first stage.""" + from megatron.bridge.models.mimo.mimo_builder import is_pp_first_stage + + mock_pp_group = MagicMock() + mock_get_ranks.return_value = [0, 4, 8, 12] + mock_get_rank.return_value = 4 + + assert is_pp_first_stage(mock_pp_group) is False + + def test_is_pp_first_stage_none_group(self): + """Test is_pp_first_stage returns True for None group (no PP).""" + from megatron.bridge.models.mimo.mimo_builder import is_pp_first_stage + + assert is_pp_first_stage(None) is True + + @patch("torch.distributed.get_process_group_ranks") + @patch("torch.distributed.get_rank") + def test_is_pp_last_stage_true(self, mock_get_rank, mock_get_ranks): + """Test is_pp_last_stage returns True for last stage.""" + from megatron.bridge.models.mimo.mimo_builder import is_pp_last_stage + + mock_pp_group = MagicMock() + mock_get_ranks.return_value = [0, 4, 8, 12] + mock_get_rank.return_value = 12 + + assert is_pp_last_stage(mock_pp_group) is True + + @patch("torch.distributed.get_process_group_ranks") + @patch("torch.distributed.get_rank") + def test_is_pp_last_stage_false(self, mock_get_rank, mock_get_ranks): + """Test is_pp_last_stage returns False for non-last stage.""" + from megatron.bridge.models.mimo.mimo_builder import is_pp_last_stage + + mock_pp_group = MagicMock() + mock_get_ranks.return_value = [0, 4, 8, 12] + mock_get_rank.return_value = 4 + + assert is_pp_last_stage(mock_pp_group) is False + + def test_is_pp_last_stage_none_group(self): + """Test is_pp_last_stage returns True for None group (no PP).""" + from megatron.bridge.models.mimo.mimo_builder import is_pp_last_stage + + assert is_pp_last_stage(None) is True + class TestProcessGroupCollectionWithEmbeddingGroups: """Test that ProcessGroupCollection includes embedding groups.""" @patch("megatron.bridge.models.mimo.mimo_provider.is_pp_last_stage") @patch("megatron.bridge.models.mimo.mimo_provider.is_pp_first_stage") - @patch("megatron.bridge.models.mimo.mimo_provider.create_embedding_and_position_groups") + @patch("megatron.bridge.models.mimo.mimo_provider.populate_embedding_and_position_groups") @patch("torch.distributed.get_rank") def test_pg_collection_includes_embedding_groups_first_stage( self, mock_get_rank, mock_populate, mock_is_first, mock_is_last @@ -597,7 +591,7 @@ def test_pg_collection_includes_embedding_groups_first_stage( language_spec = ModuleSpec(module=Mock, params={"config": Mock()}) mimo_parallelism_config = MimoParallelismConfig( module_parallelisms={ - "llm": ModuleParallelismConfig(tensor_model_parallel_size=2, data_parallel_size=2), + "language": 
ModuleParallelismConfig(tensor_model_parallel_size=2, data_parallel_size=2),
             },
         )
@@ -612,15 +606,15 @@ def test_pg_collection_includes_embedding_groups_first_stage(
             mimo_parallelism_config=mimo_parallelism_config,
         )
 
-        pg_collections = provider._get_pg_collections_from_grids({"llm": mock_grid})
+        pg_collections = provider._get_pg_collections_from_grids({"language": mock_grid})
 
-        # First stage should have pos_embd but not embd (not last stage)
-        assert pg_collections["llm"].pos_embd == mock_pos_embd
-        assert pg_collections["llm"].embd == mock_embd  # First stage gets embd too
+        # First stage should have both embedding groups
+        assert pg_collections["language"].pos_embd == mock_pos_embd
+        assert pg_collections["language"].embd == mock_embd  # First stage gets embd too
 
     @patch("megatron.bridge.models.mimo.mimo_provider.is_pp_last_stage")
     @patch("megatron.bridge.models.mimo.mimo_provider.is_pp_first_stage")
-    @patch("megatron.bridge.models.mimo.mimo_provider.create_embedding_and_position_groups")
+    @patch("megatron.bridge.models.mimo.mimo_provider.populate_embedding_and_position_groups")
     @patch("torch.distributed.get_rank")
     def test_pg_collection_middle_stage_no_embedding_groups(
         self, mock_get_rank, mock_populate, mock_is_first, mock_is_last
@@ -636,7 +630,7 @@ def test_pg_collection_middle_stage_no_embedding_groups(
         language_spec = ModuleSpec(module=Mock, params={"config": Mock()})
         mimo_parallelism_config = MimoParallelismConfig(
             module_parallelisms={
-                "llm": ModuleParallelismConfig(tensor_model_parallel_size=2, data_parallel_size=2),
+                "language": ModuleParallelismConfig(tensor_model_parallel_size=2, data_parallel_size=2),
             },
         )
 
@@ -651,8 +645,68 @@
             mimo_parallelism_config=mimo_parallelism_config,
         )
 
-        pg_collections = provider._get_pg_collections_from_grids({"llm": mock_grid})
+        pg_collections = provider._get_pg_collections_from_grids({"language": mock_grid})
 
         # Middle stage should have neither embedding group
-        assert pg_collections["llm"].pos_embd is None
-        assert pg_collections["llm"].embd is None
+        assert pg_collections["language"].pos_embd is None
+        assert pg_collections["language"].embd is None
+
+    @patch("megatron.bridge.models.mimo.mimo_provider.is_pp_last_stage")
+    @patch("megatron.bridge.models.mimo.mimo_provider.is_pp_first_stage")
+    @patch("megatron.bridge.models.mimo.mimo_provider.populate_embedding_and_position_groups")
+    @patch("torch.distributed.get_rank")
+    def test_pg_collection_includes_composite_groups(self, mock_get_rank, mock_populate, mock_is_first, mock_is_last):
+        """Test that pg_collection includes the mp, tp_ep_pp, and dp_cp composite groups."""
+        mock_get_rank.return_value = 0
+        mock_populate.return_value = (MagicMock(), MagicMock())
+        mock_is_first.return_value = True
+        mock_is_last.return_value = True
+
+        language_spec = ModuleSpec(module=Mock, params={"config": Mock()})
+        mimo_parallelism_config = MimoParallelismConfig(
+            module_parallelisms={
+                "language": ModuleParallelismConfig(tensor_model_parallel_size=2, data_parallel_size=2),
+            },
+        )
+
+        mock_tp = MagicMock(name="tp_pg")
+        mock_dp = MagicMock(name="dp_pg")
+        mock_pp = MagicMock(name="pp_pg")
+        mock_cp = MagicMock(name="cp_pg")
+        mock_ep = MagicMock(name="ep_pg")
+        mock_dp_cp = MagicMock(name="dp_cp_pg")
+        mock_mp = MagicMock(name="mp_pg")
+        mock_tp_ep_pp = MagicMock(name="tp_ep_pp_pg")
+
+        pg_map = {
+            ("tp",): mock_tp,
+            ("dp",): mock_dp,
+            ("pp",): mock_pp,
+            ("cp",): mock_cp,
+            ("ep",): mock_ep,
+            ("dp", "cp"): mock_dp_cp,
+            ("tp", "pp"): mock_mp,
+            ("tp", "ep", "pp"): mock_tp_ep_pp,
+        }
+
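+        # The fake grid resolves get_pg(dims) through this map, one mock per group.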
mock_grid = MagicMock() + mock_grid.rank_offset = 0 + mock_grid.size = 4 + mock_grid.get_pg.side_effect = lambda dims: pg_map[tuple(dims)] + + provider = MimoModelProvider( + language_model_spec=language_spec, + mimo_parallelism_config=mimo_parallelism_config, + ) + + pg_collections = provider._get_pg_collections_from_grids({"language": mock_grid}) + + pgc = pg_collections["language"] + assert pgc.tp == mock_tp + assert pgc.dp == mock_dp + assert pgc.pp == mock_pp + assert pgc.cp == mock_cp + assert pgc.ep == mock_ep + assert pgc.dp_cp == mock_dp_cp + assert pgc.mp == mock_mp + assert pgc.tp_ep_pp == mock_tp_ep_pp diff --git a/tests/unit_tests/training/mimo/__init__.py b/tests/unit_tests/training/mimo/__init__.py new file mode 100644 index 0000000000..dd58698a87 --- /dev/null +++ b/tests/unit_tests/training/mimo/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +"""Unit tests for MIMO training modules.""" diff --git a/tests/unit_tests/training/mimo/test_mimo_config.py b/tests/unit_tests/training/mimo/test_mimo_config.py index 3770043d48..3548c0853c 100644 --- a/tests/unit_tests/training/mimo/test_mimo_config.py +++ b/tests/unit_tests/training/mimo/test_mimo_config.py @@ -22,7 +22,7 @@ def test_mimo_heterogeneous_rank_offset_overlap(): """Test that overlapping rank ranges are detected in heterogeneous deployment.""" module_parallelisms = { "encoder": ModuleParallelismConfig(tensor_model_parallel_size=1, data_parallel_size=4, rank_offset=0), - "llm": ModuleParallelismConfig(tensor_model_parallel_size=1, data_parallel_size=4, rank_offset=2), + "language": ModuleParallelismConfig(tensor_model_parallel_size=1, data_parallel_size=4, rank_offset=2), } mimo_parallelism_config = MimoParallelismConfig( module_parallelisms=module_parallelisms, @@ -36,7 +36,7 @@ def test_mimo_heterogeneous_valid_contiguous(): # Note: encoder DP must be >= LLM DP for embedding alignment module_parallelisms = { "encoder": ModuleParallelismConfig(tensor_model_parallel_size=1, data_parallel_size=4, rank_offset=0), - "llm": ModuleParallelismConfig(tensor_model_parallel_size=1, data_parallel_size=2, rank_offset=4), + "language": ModuleParallelismConfig(tensor_model_parallel_size=1, data_parallel_size=2, rank_offset=4), } mimo_parallelism_config = MimoParallelismConfig( module_parallelisms=module_parallelisms, diff --git a/tests/unit_tests/training/mimo/test_mimo_parallel_utils.py b/tests/unit_tests/training/mimo/test_mimo_parallel_utils.py new file mode 100644 index 0000000000..1b58e1ea1a --- /dev/null +++ b/tests/unit_tests/training/mimo/test_mimo_parallel_utils.py @@ -0,0 +1,283 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
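+# NOTE: these tests treat a grid as owning the half-open global-rank range
+# [rank_offset, rank_offset + size); the boundary tests below rely on that.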
+"""Unit tests for MIMO parallel utilities.""" + +from unittest.mock import MagicMock, patch + +import pytest + + +class TestIsCurrentRankInGrid: + """Test cases for is_current_rank_in_grid().""" + + @patch("megatron.bridge.training.mimo_parallel_utils.dist") + def test_rank_in_grid(self, mock_dist): + """Test rank within grid range returns True.""" + from megatron.bridge.training.mimo_parallel_utils import is_current_rank_in_grid + + mock_dist.get_rank.return_value = 2 + mock_grid = MagicMock() + mock_grid.rank_offset = 0 + mock_grid.size = 4 + + assert is_current_rank_in_grid(mock_grid) is True + + @patch("megatron.bridge.training.mimo_parallel_utils.dist") + def test_rank_not_in_grid(self, mock_dist): + """Test rank outside grid range returns False.""" + from megatron.bridge.training.mimo_parallel_utils import is_current_rank_in_grid + + mock_dist.get_rank.return_value = 5 + mock_grid = MagicMock() + mock_grid.rank_offset = 0 + mock_grid.size = 4 + + assert is_current_rank_in_grid(mock_grid) is False + + @patch("megatron.bridge.training.mimo_parallel_utils.dist") + def test_rank_at_grid_boundary(self, mock_dist): + """Test rank at grid boundary.""" + from megatron.bridge.training.mimo_parallel_utils import is_current_rank_in_grid + + mock_grid = MagicMock() + mock_grid.rank_offset = 4 + mock_grid.size = 4 + + # At start boundary (inclusive) + mock_dist.get_rank.return_value = 4 + assert is_current_rank_in_grid(mock_grid) is True + + # At end boundary (exclusive) + mock_dist.get_rank.return_value = 8 + assert is_current_rank_in_grid(mock_grid) is False + + @patch("megatron.bridge.training.mimo_parallel_utils.dist") + def test_rank_before_grid(self, mock_dist): + """Test rank before grid range returns False.""" + from megatron.bridge.training.mimo_parallel_utils import is_current_rank_in_grid + + mock_dist.get_rank.return_value = 2 + mock_grid = MagicMock() + mock_grid.rank_offset = 4 + mock_grid.size = 4 + + assert is_current_rank_in_grid(mock_grid) is False + + +class TestValidateNoStubRanks: + """Test cases for validate_no_stub_ranks().""" + + def test_all_ranks_participate(self): + """Test validation passes when all ranks participate.""" + from megatron.bridge.training.mimo_parallel_utils import validate_no_stub_ranks + + mock_grid1 = MagicMock() + mock_grid1.rank_offset = 0 + mock_grid1.size = 4 + + mock_grid2 = MagicMock() + mock_grid2.rank_offset = 4 + mock_grid2.size = 4 + + module_to_grid_map = { + "encoder": mock_grid1, + "language": mock_grid2, + } + + # Should not raise + validate_no_stub_ranks(module_to_grid_map, world_size=8) + + def test_stub_ranks_detected(self): + """Test validation fails when stub ranks exist.""" + from megatron.bridge.training.mimo_parallel_utils import validate_no_stub_ranks + + mock_grid = MagicMock() + mock_grid.rank_offset = 0 + mock_grid.size = 4 + + module_to_grid_map = {"language": mock_grid} + + with pytest.raises(ValueError, match="do not participate in any module"): + validate_no_stub_ranks(module_to_grid_map, world_size=8) + + def test_overlapping_grids(self): + """Test validation with overlapping grids (colocated case).""" + from megatron.bridge.training.mimo_parallel_utils import validate_no_stub_ranks + + mock_grid1 = MagicMock() + mock_grid1.rank_offset = 0 + mock_grid1.size = 4 + + mock_grid2 = MagicMock() + mock_grid2.rank_offset = 0 + mock_grid2.size = 4 + + module_to_grid_map = { + "encoder": mock_grid1, + "language": mock_grid2, + } + + # Should not raise (all 4 ranks participate) + validate_no_stub_ranks(module_to_grid_map, 
world_size=4) + + +class TestValidateDataLoaderContract: + """Test cases for validate_data_loader_contract().""" + + def test_valid_configuration(self): + """Test validation passes for valid configuration.""" + from megatron.bridge.training.mimo_parallel_utils import validate_data_loader_contract + + mock_grid = MagicMock() + mock_grid.get_pg_size.return_value = 2 # DP size = 2 + + mock_infra = MagicMock() + mock_infra.module_to_grid_map = {"language": mock_grid} + + # global_batch=16, dp=2, per_dp_batch=8, microbatches=4, micro_batch_size=2 + # 4 * 2 = 8 == 16 / 2 ✓ + validate_data_loader_contract( + infra=mock_infra, + global_batch_size=16, + micro_batch_size=2, + num_microbatches=4, + ) + + def test_batch_not_divisible_by_dp(self): + """Test validation fails when batch not divisible by DP size.""" + from megatron.bridge.training.mimo_parallel_utils import validate_data_loader_contract + + mock_grid = MagicMock() + mock_grid.get_pg_size.return_value = 3 # DP size = 3 + + mock_infra = MagicMock() + mock_infra.module_to_grid_map = {"language": mock_grid} + + with pytest.raises(ValueError, match="not divisible"): + validate_data_loader_contract( + infra=mock_infra, + global_batch_size=16, + micro_batch_size=2, + num_microbatches=4, + ) + + +class TestBuildPgCollectionForSchedule: + """Test cases for build_pg_collection_for_schedule().""" + + def test_fallback_to_list(self): + """Test fallback to list when MultiModuleProcessGroupCollection not available.""" + from megatron.bridge.training.mimo_parallel_utils import build_pg_collection_for_schedule + + mock_pg1 = MagicMock() + mock_pg2 = MagicMock() + + mock_infra = MagicMock() + mock_infra.pg_collections = { + "encoder": mock_pg1, + "language": mock_pg2, + } + + # This will likely fall back to list since import may fail in test env + result = build_pg_collection_for_schedule(mock_infra) + + # Should be either a list or MultiModuleProcessGroupCollection + assert result is not None + + def test_filters_none_pg_collections(self): + """Test that None pg_collections are filtered out.""" + from megatron.bridge.training.mimo_parallel_utils import build_pg_collection_for_schedule + + mock_pg = MagicMock() + + mock_infra = MagicMock() + mock_infra.pg_collections = { + "encoder": None, # Non-participating module + "language": mock_pg, + } + + result = build_pg_collection_for_schedule(mock_infra) + + # Should filter out None values + if isinstance(result, list): + assert len(result) == 1 + assert mock_pg in result + + +class TestMultimoduleNoSync: + """Test cases for multimodule_no_sync context manager.""" + + @patch("megatron.bridge.training.mimo_parallel_utils.is_current_rank_in_grid") + def test_enters_and_exits_contexts(self, mock_in_grid): + """Test that no_sync contexts are properly entered and exited.""" + from megatron.bridge.training.mimo_parallel_utils import multimodule_no_sync + + mock_in_grid.return_value = True + + mock_module = MagicMock() + mock_context = MagicMock() + mock_module.no_sync.return_value = mock_context + + mock_grid = MagicMock() + + module_to_grid_tuple = [(mock_module, mock_grid)] + + with multimodule_no_sync(module_to_grid_tuple=module_to_grid_tuple): + pass + + # Verify context was entered and exited + mock_context.__enter__.assert_called_once() + mock_context.__exit__.assert_called_once() + + @patch("megatron.bridge.training.mimo_parallel_utils.is_current_rank_in_grid") + def test_skips_non_participating_modules(self, mock_in_grid): + """Test that non-participating modules are skipped.""" + from 
megatron.bridge.training.mimo_parallel_utils import multimodule_no_sync + + mock_in_grid.return_value = False # Not participating + + mock_module = MagicMock() + mock_grid = MagicMock() + + module_to_grid_tuple = [(mock_module, mock_grid)] + + with multimodule_no_sync(module_to_grid_tuple=module_to_grid_tuple): + pass + + # no_sync should not be called + mock_module.no_sync.assert_not_called() + + +class TestZeroGradBufferForMultimodule: + """Test cases for zero_grad_buffer_for_multimodule().""" + + @patch("megatron.bridge.training.mimo_parallel_utils.is_current_rank_in_grid") + def test_zeros_grad_buffers(self, mock_in_grid): + """Test gradient buffers are zeroed for participating modules.""" + from megatron.bridge.training.mimo_parallel_utils import zero_grad_buffer_for_multimodule + + mock_in_grid.return_value = True + + mock_module = MagicMock() + mock_grid = MagicMock() + + module_to_grid_tuple = [(mock_module, mock_grid)] + + zero_grad_buffer_for_multimodule(module_to_grid_tuple) + + mock_module.zero_grad_buffer.assert_called_once() + + @patch("megatron.bridge.training.mimo_parallel_utils.is_current_rank_in_grid") + def test_skips_non_participating(self, mock_in_grid): + """Test non-participating modules are skipped.""" + from megatron.bridge.training.mimo_parallel_utils import zero_grad_buffer_for_multimodule + + mock_in_grid.return_value = False + + mock_module = MagicMock() + mock_grid = MagicMock() + + module_to_grid_tuple = [(mock_module, mock_grid)] + + zero_grad_buffer_for_multimodule(module_to_grid_tuple) + + mock_module.zero_grad_buffer.assert_not_called() diff --git a/tests/unit_tests/training/mimo/test_mimo_step.py b/tests/unit_tests/training/mimo/test_mimo_step.py new file mode 100644 index 0000000000..dc26ff26d2 --- /dev/null +++ b/tests/unit_tests/training/mimo/test_mimo_step.py @@ -0,0 +1,173 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
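+# NOTE: as exercised below, loss_func(loss_mask, output_tensor) returns
+# (masked loss sum, unmasked-token count, metrics dict with an "lm loss" key),
+# and get_batch returns None for a missing or exhausted iterator.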
+"""Unit tests for MIMO forward step functions.""" + +from unittest.mock import MagicMock, patch + +import pytest +import torch + + +class TestLossFunc: + """Test cases for loss_func().""" + + def test_loss_computation(self): + """Test loss is computed correctly with mask.""" + from megatron.bridge.training.mimo_step import loss_func + + # Create test data + output_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0]) + loss_mask = torch.tensor([1.0, 1.0, 0.0, 1.0]) # Mask out 3rd element + + total_loss, num_tokens, metrics = loss_func(loss_mask, output_tensor) + + # Expected: (1.0*1 + 2.0*1 + 3.0*0 + 4.0*1) = 7.0 + assert total_loss.item() == 7.0 + # Expected tokens: 3 (sum of mask) + assert num_tokens.item() == 3 + # Check metrics dict structure + assert "lm loss" in metrics + + def test_loss_with_all_ones_mask(self): + """Test loss with all-ones mask.""" + from megatron.bridge.training.mimo_step import loss_func + + output_tensor = torch.tensor([1.0, 2.0, 3.0]) + loss_mask = torch.ones(3) + + total_loss, num_tokens, metrics = loss_func(loss_mask, output_tensor) + + assert total_loss.item() == 6.0 + assert num_tokens.item() == 3 + + def test_loss_with_all_zeros_mask(self): + """Test loss with all-zeros mask.""" + from megatron.bridge.training.mimo_step import loss_func + + output_tensor = torch.tensor([1.0, 2.0, 3.0]) + loss_mask = torch.zeros(3) + + total_loss, num_tokens, metrics = loss_func(loss_mask, output_tensor) + + assert total_loss.item() == 0.0 + assert num_tokens.item() == 0 + + +class TestGetBatch: + """Test cases for get_batch().""" + + def test_returns_none_for_none_iterator(self): + """Test returns None when iterator is None.""" + from megatron.bridge.training.mimo_step import get_batch + + result = get_batch(None) + assert result is None + + def test_returns_none_on_stop_iteration(self): + """Test returns None when iterator is exhausted.""" + from megatron.bridge.training.mimo_step import get_batch + + empty_iter = iter([]) + result = get_batch(empty_iter) + assert result is None + + def test_returns_batch_from_iterator(self): + """Test returns batch from iterator.""" + from megatron.bridge.training.mimo_step import get_batch + + batch = {"input_ids": torch.tensor([1, 2, 3])} + data_iter = iter([batch]) + + result = get_batch(data_iter) + + assert result is not None + assert "input_ids" in result + + +class TestForwardStep: + """Test cases for forward_step().""" + + @patch("megatron.bridge.training.mimo_step.unwrap_mimo_model") + def test_forward_step_last_stage(self, mock_unwrap): + """Test forward step at last pipeline stage returns loss func.""" + from megatron.bridge.training.mimo_step import forward_step + + # Create mock state + mock_state = MagicMock() + + # Create mock model with role=None (indicates last stage) + mock_model = MagicMock() + mock_model.role = None # role=None means is_last_stage=True + mock_output = torch.tensor([1.0, 2.0]) + mock_loss_mask = torch.ones(2) + mock_model.return_value = (mock_output, mock_loss_mask) + + # unwrap_mimo_model returns the mock model itself + mock_unwrap.return_value = mock_model + + # Create mock iterator + batch = {"input_ids": torch.tensor([1, 2])} + data_iter = iter([batch]) + + output, loss_fn = forward_step(mock_state, data_iter, mock_model) + + # At last stage, should return loss function + assert loss_fn is not None + assert callable(loss_fn) + + @patch("megatron.bridge.training.mimo_step.unwrap_mimo_model") + def test_forward_step_intermediate_stage(self, mock_unwrap): + """Test forward step at intermediate stage returns 
+        from megatron.bridge.training.mimo_step import forward_step
+
+        mock_state = MagicMock()
+        mock_model = MagicMock()
+        # Configure role to indicate intermediate stage (not last stage)
+        mock_role = MagicMock()
+        mock_role.has_language_module = True
+        mock_role.has_modality_modules = False
+        mock_role.is_last_stage.return_value = False
+        mock_role.is_first_stage.return_value = True
+        mock_model.role = mock_role
+        mock_model.return_value = (torch.tensor([1.0]), None)
+
+        mock_unwrap.return_value = mock_model
+
+        batch = {"input_ids": torch.tensor([1, 2])}
+        data_iter = iter([batch])
+
+        output, loss_fn = forward_step(mock_state, data_iter, mock_model)
+
+        # Intermediate stage should return None for loss_fn
+        assert loss_fn is None
+
+    @patch("megatron.bridge.training.mimo_step.unwrap_mimo_model")
+    def test_forward_step_rejects_dict_at_last_stage(self, mock_unwrap):
+        """Test forward step raises error if dict returned at last stage."""
+        from megatron.bridge.training.mimo_step import forward_step
+
+        mock_state = MagicMock()
+        mock_model = MagicMock()
+        mock_model.role = None  # role=None means is_last_stage=True
+        # Return dict (incorrect for last stage)
+        mock_model.return_value = ({"encoder": torch.tensor([1.0])}, None)
+
+        mock_unwrap.return_value = mock_model
+
+        batch = {"input_ids": torch.tensor([1, 2])}
+        data_iter = iter([batch])
+
+        with pytest.raises(ValueError, match="Last pipeline stage must return scalar loss"):
+            forward_step(mock_state, data_iter, mock_model)
+
+    def test_forward_step_uses_global_state_signature(self):
+        """Test forward step uses 3-arg signature with GlobalState."""
+        import inspect
+
+        from megatron.bridge.training.mimo_step import forward_step
+
+        sig = inspect.signature(forward_step)
+        params = list(sig.parameters.keys())
+
+        # Should have state as first parameter
+        assert params[0] == "state"
+        assert len(params) == 3
diff --git a/tests/unit_tests/training/mimo/test_pretrain_mimo.py b/tests/unit_tests/training/mimo/test_pretrain_mimo.py
new file mode 100644
index 0000000000..b324ae59b3
--- /dev/null
+++ b/tests/unit_tests/training/mimo/test_pretrain_mimo.py
@@ -0,0 +1,177 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
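+#
+# NOTE: torch.distributed is never initialized in these tests; `dist` is
+# patched and grid membership is simulated via is_current_rank_in_grid side
+# effects, so everything runs in a single CPU-only process.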
+"""Unit tests for MIMO pretrain entrypoint wiring.""" + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + + +def _make_cfg(): + cfg = MagicMock() + cfg.train = SimpleNamespace( + rampup_batch_size=None, + global_batch_size=1, + micro_batch_size=1, + decrease_batch_size_if_needed=False, + ) + cfg.data_parallel_size = 1 + return cfg + + +def _make_setup_output(module_to_grid_map): + return SimpleNamespace( + model=MagicMock(), + mimo_infra=SimpleNamespace(module_to_grid_map=module_to_grid_map), + multimodule_communicator=MagicMock(), + multimodule_pg_collection=MagicMock(), + module_to_grid_tuple=[(MagicMock(), MagicMock())], + optimizer=MagicMock(), + schedulers={}, + train_data_iterator=iter([]), + valid_data_iterator=None, + global_state=MagicMock(), + ) + + +@patch( + "megatron.bridge.training.pretrain_mimo.is_current_rank_in_grid", + side_effect=lambda grid: grid.rank_offset <= 4 < (grid.rank_offset + grid.size), +) +@patch("megatron.bridge.training.pretrain_mimo.dist") +def test_set_mimo_random_seeds_calls_model_parallel_cuda_manual_seed(mock_dist, _mock_in_grid): + """_set_mimo_random_seeds should derive TP/PP ranks from grids and call model_parallel_cuda_manual_seed.""" + from megatron.bridge.training.pretrain_mimo import _set_mimo_random_seeds + + mock_dist.get_rank.return_value = 4 # e.g. first rank of vision encoder + + # Build a mock grid: vision ranks [4,8), TP=2, PP=1 + tp_pg = MagicMock() + pp_pg = MagicMock() + mock_dist.get_group_rank.side_effect = lambda pg, rank: {tp_pg: 0, pp_pg: 0}[pg] + + grid = MagicMock() + grid.rank_offset = 4 + grid.size = 4 + grid.get_pg.side_effect = lambda dims: {"tp": tp_pg, "pp": pp_pg}[dims[0]] + + mimo_infra = SimpleNamespace(module_to_grid_map={"vision": grid}) + cfg = SimpleNamespace(rng=SimpleNamespace(seed=42)) + + with patch("megatron.core.tensor_parallel.model_parallel_cuda_manual_seed") as mock_seed: + import torch + + with patch.object(torch.cuda, "device_count", return_value=1): + _set_mimo_random_seeds(cfg, mimo_infra) + + # pp_rank=0, so seed stays 42. tp_rank=0 passed explicitly. 
+    mock_seed.assert_called_once_with(42, tp_rank=0, ep_rank=0, etp_rank=0)
+
+
+@patch(
+    "megatron.bridge.training.pretrain_mimo.is_current_rank_in_grid",
+    side_effect=lambda grid: grid.rank_offset <= 2 < (grid.rank_offset + grid.size),
+)
+@patch("megatron.bridge.training.pretrain_mimo.dist")
+def test_set_mimo_random_seeds_offsets_by_pp_rank(mock_dist, _mock_in_grid):
+    """PP rank > 0 should offset the seed by 100 * pp_rank."""
+    from megatron.bridge.training.pretrain_mimo import _set_mimo_random_seeds
+
+    mock_dist.get_rank.return_value = 2
+
+    tp_pg = MagicMock()
+    pp_pg = MagicMock()
+    # tp_rank=1, pp_rank=1
+    mock_dist.get_group_rank.side_effect = lambda pg, rank: {tp_pg: 1, pp_pg: 1}[pg]
+
+    grid = MagicMock()
+    grid.rank_offset = 0
+    grid.size = 4
+    grid.get_pg.side_effect = lambda dims: {"tp": tp_pg, "pp": pp_pg}[dims[0]]
+
+    mimo_infra = SimpleNamespace(module_to_grid_map={"llm": grid})
+    cfg = SimpleNamespace(rng=SimpleNamespace(seed=42))
+
+    with patch("megatron.core.tensor_parallel.model_parallel_cuda_manual_seed") as mock_seed:
+        import torch
+
+        with patch.object(torch.cuda, "device_count", return_value=1):
+            _set_mimo_random_seeds(cfg, mimo_infra)
+
+    # seed = 42 + 100 * 1 = 142, tp_rank=1
+    mock_seed.assert_called_once_with(142, tp_rank=1, ep_rank=0, etp_rank=0)
+
+
+@patch("megatron.bridge.training.pretrain_mimo.train_mimo")
+@patch("megatron.bridge.training.pretrain_mimo.setup_mimo")
+@patch("megatron.bridge.training.pretrain_mimo.dist")
+def test_pretrain_mimo_calls_setup_and_train(mock_dist, mock_setup_mimo, mock_train_mimo):
+    """pretrain_mimo should call setup_mimo then train_mimo."""
+    from megatron.bridge.training.pretrain_mimo import pretrain_mimo
+
+    cfg = _make_cfg()
+
+    mock_dist.get_rank.return_value = 0
+    setup_output = _make_setup_output(module_to_grid_map={"language": MagicMock()})
+    mock_setup_mimo.return_value = setup_output
+
+    with (
+        patch("megatron.core.num_microbatches_calculator._GLOBAL_NUM_MICROBATCHES_CALCULATOR", None),
+        patch("megatron.core.num_microbatches_calculator.init_num_microbatches_calculator"),
+    ):
+        pretrain_mimo(
+            cfg=cfg,
+            mimo_provider=MagicMock(),
+            forward_step_func=MagicMock(),
+            build_data_iterators_fn=MagicMock(),
+            global_state=MagicMock(),
+        )
+
+    mock_setup_mimo.assert_called_once()
+    mock_train_mimo.assert_called_once()
+
+
+@patch("megatron.bridge.training.pretrain_mimo.unwrap_mimo_model")
+@patch("megatron.bridge.training.pretrain_mimo.get_model_config")
+@patch("megatron.bridge.training.pretrain_mimo.dist")
+def test_setup_mimo_asserts_when_constructor_fields_missing(mock_dist, mock_get_model_config, mock_unwrap_mimo_model):
+    """setup_mimo guardrail should fail when module_to_grid_map is missing at construction."""
+    from megatron.bridge.training.pretrain_mimo import setup_mimo
+
+    cfg = _make_cfg()
+    mock_dist.get_rank.return_value = 0
+    mock_dist.get_world_size.return_value = 8
+
+    # Model with missing module_to_grid_map
+    unwrapped_model = MagicMock()
+    unwrapped_model.mimo_config = SimpleNamespace(module_to_grid_map=None)
+    mock_unwrap_mimo_model.return_value = unwrapped_model
+
+    mock_model_config = MagicMock()
+    mock_model_config.pipeline_dtype = None
+    mock_model_config.bf16 = True
+    mock_get_model_config.return_value = mock_model_config
+
+    # Provider that returns infra with an active grid map
+    mock_provider = MagicMock()
+    mock_infra = MagicMock()
+    mock_infra.module_to_grid_map = {"language": MagicMock()}
+    mock_infra.topology = {"language": []}
+    mock_infra.module_output_ndim = {"language": 3}
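+    # The provider hands setup_mimo a fully populated infra; the guardrail
+    # under test should trip on the constructed model's missing
+    # mimo_config.module_to_grid_map, not on the provider side.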
+    mock_provider.build_infra.return_value = mock_infra
+    mock_provider.provide_distributed_model.return_value = [MagicMock()]
+
+    with (
+        patch("megatron.bridge.training.pretrain_mimo.validate_no_stub_ranks"),
+        patch("megatron.bridge.training.pretrain_mimo._set_mimo_random_seeds"),
+        patch("megatron.bridge.training.pretrain_mimo.build_pg_collection_for_schedule"),
+        patch("megatron.bridge.training.pretrain_mimo.get_module_to_grid_tuple"),
+        patch("megatron.bridge.training.pretrain_mimo.MultiModulePipelineCommunicator"),
+    ):
+        with pytest.raises(AssertionError, match="module_to_grid_map must be set"):
+            setup_mimo(
+                cfg=cfg,
+                mimo_provider=mock_provider,
+                build_optimizer=True,
+                global_state=MagicMock(),
+            )
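
For reference, a minimal sketch of the loss_func and get_batch behavior that test_mimo_step.py pins down. Only the masked-sum reduction, the token count, the "lm loss" metrics key, and the None-on-exhaustion contract are taken from the assertions above; everything else about the real megatron.bridge.training.mimo_step module (shapes, reduction across DP ranks, exact return types) should be read as assumptions, not as the actual implementation:

    from typing import Iterator, Optional

    import torch


    def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
        """Masked LM loss: sum of per-token losses where the mask is 1.

        Hypothetical sketch; mirrors only what the unit tests assert.
        """
        loss_mask = loss_mask.float().view(-1)
        total_loss = torch.sum(output_tensor.view(-1) * loss_mask)
        num_tokens = loss_mask.sum().to(torch.int)
        # Keyed as "lm loss" so the tests' metrics-dict check passes.
        return total_loss, num_tokens, {"lm loss": total_loss.detach()}


    def get_batch(data_iterator: Optional[Iterator]) -> Optional[dict]:
        """Fetch the next batch; None signals 'no data on this rank'."""
        if data_iterator is None:
            return None
        try:
            return next(data_iterator)
        except StopIteration:
            return None

Returning None rather than raising on an exhausted iterator is what lets heterogeneous MIMO ranks without a local dataloader participate in the same training loop.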