diff --git a/src/megatron/bridge/data/mimo/__init__.py b/src/megatron/bridge/data/mimo/__init__.py index 045d87152e..8713d2c68c 100644 --- a/src/megatron/bridge/data/mimo/__init__.py +++ b/src/megatron/bridge/data/mimo/__init__.py @@ -1,16 +1,16 @@ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. """MIMO multi-encoder data loading utilities.""" -from megatron.bridge.data.mimo.dataset import MimoDataset -from megatron.bridge.data.mimo.collate import mimo_collate_fn -from megatron.bridge.data.mimo.dp_utils import get_mimo_dp_info -from megatron.bridge.data.mimo.loaders import build_mimo_data_loaders - # Providers from megatron.bridge.data.mimo.base_provider import MimoDatasetProvider +from megatron.bridge.data.mimo.collate import mimo_collate_fn +from megatron.bridge.data.mimo.dataset import MimoDataset +from megatron.bridge.data.mimo.dp_utils import get_mimo_dp_info, get_mimo_sampling_info, slice_batch_for_mimo from megatron.bridge.data.mimo.hf_provider import HFMimoDatasetProvider +from megatron.bridge.data.mimo.loaders import build_mimo_data_loaders from megatron.bridge.data.mimo.mock_provider import MockMimoProvider + __all__ = [ # Core "MimoDataset", @@ -21,5 +21,7 @@ "MockMimoProvider", # Utilities "get_mimo_dp_info", + "get_mimo_sampling_info", + "slice_batch_for_mimo", "build_mimo_data_loaders", ] diff --git a/src/megatron/bridge/data/mimo/dp_utils.py b/src/megatron/bridge/data/mimo/dp_utils.py index ab327bc27b..eb3dd1b8ea 100644 --- a/src/megatron/bridge/data/mimo/dp_utils.py +++ b/src/megatron/bridge/data/mimo/dp_utils.py @@ -13,64 +13,96 @@ from megatron.bridge.models.mimo.mimo_config import MimoParallelismConfig +def _find_rank_module( + grids: Dict[str, "HyperCommGrid"], +) -> Tuple["HyperCommGrid | None", "str | None"]: + """Find which module grid the current rank belongs to.""" + current_rank = dist.get_rank() + for module_name, grid in grids.items(): + if grid.rank_offset <= current_rank < (grid.rank_offset + grid.size): + return grid, module_name + return None, None + + +def _needs_data_for_module(grid: "HyperCommGrid", module_name: str) -> bool: + """Determine if the current rank needs to load data for the given module. + + LLM: first and last PP stage need data (input_ids and labels respectively). + Encoders: only the first PP stage needs raw modality inputs. + """ + pp_group = grid.get_pg(["pp"]) + pp_rank = pp_group.rank() + pp_size = pp_group.size() + if module_name == "llm": + return (pp_rank == 0) or (pp_rank == pp_size - 1) + return pp_rank == 0 + + def get_mimo_dp_info( mimo_cfg: "MimoParallelismConfig", grids: Dict[str, "HyperCommGrid"], ) -> Tuple[int, int, bool, str]: - """Get DP rank, size, data-loading responsibility, and loader module for MIMO. - - Determines which module's DP settings to use for data loading based on - current rank's participation in heterogeneous deployment. - - In heterogeneous mode, each rank uses its own module's DP settings. - + """Get **module-local** DP rank, size, data-loading flag, and module name. + + Returns the DP settings for the module that the current rank participates + in. These are used by :func:`slice_batch_for_mimo` to sub-shard a global + micro-batch into per-module DP shards. + + .. note:: + Do **not** use these values to construct a ``DistributedSampler``. + For sampler construction use :func:`get_mimo_sampling_info` instead, + which returns settings that keep all data-loading ranks synchronised + on the same sample order. + Args: mimo_cfg: MIMO parallelism configuration. grids: Module name to HyperCommGrid mapping from build_hypercomm_grids(). - + Returns: - Tuple of (dp_rank, dp_size, needs_data, loader_module): - - dp_rank: This rank's position in DP group. - - dp_size: Size of DP group for data sharding. - - needs_data: Whether this rank needs to load data (first/last PP stage). - - loader_module: Which module's DP settings are being used. - - Example: - >>> from megatron.bridge.models.mimo.mimo_builder import build_hypercomm_grids - >>> grids = build_hypercomm_grids(mimo_cfg) - >>> dp_rank, dp_size, needs_data, loader_module = get_mimo_dp_info(mimo_cfg, grids) - >>> if needs_data: - ... # Build data loader with dp_rank and dp_size - ... sampler = DistributedSampler(dataset, num_replicas=dp_size, rank=dp_rank) + Tuple of (dp_rank, dp_size, needs_data, loader_module). """ - current_rank = dist.get_rank() - - # Heterogeneous: find which module this rank belongs to - my_grid = None - my_module = None - for module_name, grid in grids.items(): - if grid.rank_offset <= current_rank < (grid.rank_offset + grid.size): - my_grid = grid - my_module = module_name - break - + my_grid, my_module = _find_rank_module(grids) if my_grid is None or my_module is None: - # Rank doesn't participate in any module return 0, 1, False, "llm" dp_rank = my_grid.get_pg(["dp"]).rank() dp_size = my_grid.get_pg(["dp"]).size() + needs_data = _needs_data_for_module(my_grid, my_module) + return dp_rank, dp_size, needs_data, my_module - pp_group = my_grid.get_pg(["pp"]) - pp_rank = pp_group.rank() - pp_size = pp_group.size() - if my_module == "llm": - needs_data = (pp_rank == 0) or (pp_rank == pp_size - 1) - else: - needs_data = pp_rank == 0 +def get_mimo_sampling_info( + mimo_cfg: "MimoParallelismConfig", + grids: Dict[str, "HyperCommGrid"], +) -> Tuple[int, int, bool]: + """Get sampler DP rank, size, and data-loading flag for MIMO. - return dp_rank, dp_size, needs_data, my_module + In heterogeneous MIMO, modules may have different DP sizes. The data + loader must give every data-loading rank the **same global micro-batch** + so that :func:`slice_batch_for_mimo` (called in the forward step) can + sub-shard it consistently with the :class:`BridgeCommunicator` fan-in / + fan-out routing. + + This function therefore returns ``dp_size=1, dp_rank=0`` for all ranks, + disabling DP sharding at the sampler level. Per-module DP sharding is + deferred to :func:`slice_batch_for_mimo`. + + Args: + mimo_cfg: MIMO parallelism configuration. + grids: Module name to HyperCommGrid mapping. + + Returns: + Tuple of (sampler_dp_rank, sampler_dp_size, needs_data). + """ + my_grid, my_module = _find_rank_module(grids) + if my_grid is None or my_module is None: + return 0, 1, False + + needs_data = _needs_data_for_module(my_grid, my_module) + # All data-loading ranks use the same sampler settings so they load + # identical global micro-batches. Module-local DP slicing happens later + # in forward_step via slice_batch_for_mimo. + return 0, 1, needs_data def slice_batch_for_mimo( @@ -78,31 +110,33 @@ def slice_batch_for_mimo( dp_rank: int, dp_size: int, ) -> Dict[str, Any]: - """Slice a global batch for this rank's DP shard. - - Takes a global batch (same data on all ranks) and returns the portion - that this rank should process based on its DP rank and size. - - Used by both training and evaluation to ensure consistent data sharding - across heterogeneous MIMO modules. - + """Slice a global micro-batch for this rank's module-local DP shard. + + All data-loading ranks receive the same global micro-batch (the sampler + uses ``dp_size=1``). This function contiguously slices it so that each + module-local DP replica processes the correct subset. The slicing is + contiguous to match the :class:`BridgeCommunicator`'s batch-dimension + split / concatenate logic for fan-out and fan-in routing. + + Handles nested dicts (e.g. ``modality_inputs``) by recursing. + Args: batch: Global batch dictionary with tensors of shape [global_batch, ...]. - dp_rank: This rank's position in its DP group. - dp_size: Total size of the DP group. - + May contain nested dicts (e.g. modality_inputs → encoder → kwargs). + dp_rank: This rank's position in its **module-local** DP group. + dp_size: Size of the module-local DP group. + Returns: - Dict[str, Any]: Sliced batch with tensors of shape [local_batch, ...]. - + Dict with tensors sliced to shape [global_batch // dp_size, ...]. + Example: - >>> # Global batch of 16 samples, DP size 4, this is DP rank 1 - >>> global_batch = {'tokens': torch.randn(16, 2048)} - >>> local_batch = slice_batch_for_mimo(global_batch, dp_rank=1, dp_size=4) + >>> global_batch = {'tokens': torch.randn(12, 2048)} + >>> local_batch = slice_batch_for_mimo(global_batch, dp_rank=1, dp_size=3) >>> local_batch['tokens'].shape # torch.Size([4, 2048]) """ if dp_size == 1: return batch - + sliced = {} for key, value in batch.items(): if isinstance(value, torch.Tensor): @@ -111,14 +145,17 @@ def slice_batch_for_mimo( if batch_size % dp_size != 0: raise ValueError( f"Batch size {batch_size} for key '{key}' is not divisible " - f"by DP size {dp_size}" + f"by DP size {dp_size}. Ensure micro_batch_size is divisible " + f"by every module's data_parallel_size." ) local_batch_size = batch_size // dp_size start_idx = dp_rank * local_batch_size end_idx = start_idx + local_batch_size sliced[key] = value[start_idx:end_idx] + elif isinstance(value, dict): + # Recurse into nested dicts (e.g. modality_inputs) + sliced[key] = slice_batch_for_mimo(value, dp_rank, dp_size) elif isinstance(value, list) and len(value) > 0: - # Handle list values (e.g., metadata lists) list_len = len(value) if list_len % dp_size == 0: local_len = list_len // dp_size @@ -131,5 +168,5 @@ def slice_batch_for_mimo( else: # Keep non-tensor, non-list values as-is sliced[key] = value - + return sliced diff --git a/src/megatron/bridge/data/mimo/loaders.py b/src/megatron/bridge/data/mimo/loaders.py index 576f114505..5e61ca2d06 100644 --- a/src/megatron/bridge/data/mimo/loaders.py +++ b/src/megatron/bridge/data/mimo/loaders.py @@ -3,15 +3,16 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Dict, Optional, Tuple +from typing import TYPE_CHECKING, Optional, Tuple import torch from torch.utils.data import DataLoader -from megatron.bridge.data.mimo.dp_utils import get_mimo_dp_info +from megatron.bridge.data.mimo.dp_utils import get_mimo_sampling_info from megatron.bridge.training.config import DatasetBuildContext, DatasetProvider from megatron.bridge.utils.common_utils import print_rank_0 + if TYPE_CHECKING: from megatron.bridge.training.config import ConfigContainer from megatron.bridge.training.state import TrainState @@ -25,12 +26,14 @@ def build_mimo_data_loaders( valid_samples: int, test_samples: int, ) -> Tuple[Optional[DataLoader], Optional[DataLoader], Optional[DataLoader]]: - """Build MIMO data loaders with per-module DP settings. - - Creates data loaders with DP-aware sampling based on the MIMO parallelism - configuration. Only ranks that need data (first/last PP stage) will get - non-None loaders. - + """Build MIMO data loaders with globally consistent sampling. + + All data-loading ranks receive identical global micro-batches (the sampler + uses dp_size=1). Per-module DP sub-sharding is deferred to + ``slice_batch_for_mimo`` in the forward step, ensuring consistency with + the BridgeCommunicator's fan-in/fan-out routing for asymmetric DP configs. + Only ranks that need data (first/last PP stage) will get non-None loaders. + Args: cfg: Configuration container with MimoModelProvider as cfg.model. train_state: Current training state. @@ -39,14 +42,14 @@ def build_mimo_data_loaders( train_samples: Number of training samples. valid_samples: Number of validation samples. test_samples: Number of test samples. - + Returns: Tuple of (train_loader, valid_loader, test_loader). Returns (None, None, None) if this rank doesn't need data. - + Raises: ValueError: If cfg.model is not MimoModelProvider or mimo_parallelism_config is None. - + Example: >>> from megatron.bridge.data.mimo import MockMimoProvider, build_mimo_data_loaders >>> provider = MockMimoProvider( @@ -62,27 +65,36 @@ def build_mimo_data_loaders( ... ) """ from megatron.bridge.models.mimo.mimo_provider import MimoModelProvider - + if not isinstance(cfg.model, MimoModelProvider): raise ValueError("cfg.model must be MimoModelProvider for MIMO data loading.") - + if cfg.model.mimo_parallelism_config is None: raise ValueError("mimo_parallelism_config must be set for MIMO data loading.") - + if cfg.model._grids is None: raise ValueError( - "MimoModelProvider._grids is None. " - "Ensure build_model() is called before building data loaders." + "MimoModelProvider._grids is None. Ensure build_model() is called before building data loaders." + ) + + # Validate that micro_batch_size is divisible by every module's DP size. + # slice_batch_for_mimo divides the micro-batch contiguously by the module's + # DP size in forward_step; a non-divisible MBS would leave a remainder. + micro_batch_size = cfg.train.micro_batch_size + for mod_name, mod_cfg in cfg.model.mimo_parallelism_config.module_parallelisms.items(): + dp = mod_cfg.data_parallel_size + assert micro_batch_size % dp == 0, ( + f"micro_batch_size ({micro_batch_size}) must be divisible by " + f"data_parallel_size ({dp}) of module '{mod_name}'. " + f"slice_batch_for_mimo requires an evenly divisible micro-batch." ) print_rank_0("> building MIMO train, validation, and test datasets ...") # Use cached grids from build_model() grids = cfg.model._grids - - dp_rank, dp_size, needs_data, loader_module = get_mimo_dp_info( - cfg.model.mimo_parallelism_config, grids - ) + + sampler_dp_rank, sampler_dp_size, needs_data = get_mimo_sampling_info(cfg.model.mimo_parallelism_config, grids) if not needs_data: return None, None, None @@ -96,21 +108,25 @@ def build_mimo_data_loaders( ) train_ds, valid_ds, test_ds = mimo_provider.build_datasets(context) - print_rank_0(f" Built datasets: train={len(train_ds) if train_ds else 0}, " - f"valid={len(valid_ds) if valid_ds else 0}, " - f"test={len(test_ds) if test_ds else 0}") + print_rank_0( + f" Built datasets: train={len(train_ds) if train_ds else 0}, " + f"valid={len(valid_ds) if valid_ds else 0}, " + f"test={len(test_ds) if test_ds else 0}" + ) - # Build data loaders with DP-aware sampling + # Build data loaders with globally consistent sampling. + # sampler_dp_size=1 so all data-loading ranks see the same batches. + # Per-module DP sub-sharding is done later by slice_batch_for_mimo. collate_fn = mimo_provider.get_collate_fn() micro_batch_size = cfg.train.micro_batch_size - + def _make_loader(dataset, shuffle: bool = True) -> Optional[DataLoader]: if dataset is None: return None sampler = torch.utils.data.DistributedSampler( dataset, - num_replicas=dp_size, - rank=dp_rank, + num_replicas=sampler_dp_size, + rank=sampler_dp_rank, shuffle=shuffle, ) return DataLoader( @@ -126,5 +142,5 @@ def _make_loader(dataset, shuffle: bool = True) -> Optional[DataLoader]: train_loader = _make_loader(train_ds, shuffle=True) valid_loader = _make_loader(valid_ds, shuffle=False) test_loader = _make_loader(test_ds, shuffle=False) - + return train_loader, valid_loader, test_loader diff --git a/src/megatron/bridge/training/mimo_step.py b/src/megatron/bridge/training/mimo_step.py index c3e3323ae0..9671dc78ac 100644 --- a/src/megatron/bridge/training/mimo_step.py +++ b/src/megatron/bridge/training/mimo_step.py @@ -13,20 +13,41 @@ import logging from functools import partial -from typing import TYPE_CHECKING, Dict, Iterable, Optional, Tuple +from typing import Dict, Iterable, Optional, Tuple import torch from megatron.core.models.mimo import MimoModel +from megatron.bridge.data.mimo.dp_utils import slice_batch_for_mimo from megatron.bridge.training.mimo_parallel_utils import unwrap_mimo_model from megatron.bridge.training.state import GlobalState -if TYPE_CHECKING: - pass +logger = logging.getLogger(__name__) -logger = logging.getLogger(__name__) +def _get_module_dp_info( + mimo_model: MimoModel, +) -> Tuple[int, int]: + """Get module-local DP rank and size for the current rank. + + Used to slice the global micro-batch via :func:`slice_batch_for_mimo`. + Returns (0, 1) when grids are not configured (colocated mode). + """ + grids = getattr(mimo_model.mimo_config, "module_to_grid_map", None) + if not grids: + return 0, 1 + + import torch.distributed as _dist + + current_rank = _dist.get_rank() + for _name, grid in grids.items(): + if grid.rank_offset <= current_rank < (grid.rank_offset + grid.size): + dp_rank = grid.get_pg(["dp"]).rank() + dp_size = grid.get_pg(["dp"]).size() + return dp_rank, dp_size + + return 0, 1 def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor) -> Tuple: @@ -150,6 +171,12 @@ def forward_step( "get_batch returned None at a stage that requires data. " "This indicates a data-loading or parallelism misconfiguration." ) + # Slice the global micro-batch for this module's DP shard. + # All data-loading ranks receive identical batches (sampler dp_size=1). + # slice_batch_for_mimo contiguously sub-shards to match the + # BridgeCommunicator's fan-in/fan-out batch-dimension routing. + dp_rank, dp_size = _get_module_dp_info(mimo_model) + data_batch = slice_batch_for_mimo(data_batch, dp_rank, dp_size) else: # Non-data stages consume hidden states from pipeline input tensors. data_batch = {