Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions src/megatron/bridge/data/mimo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
"""MIMO multi-encoder data loading utilities."""

from megatron.bridge.data.mimo.dataset import MimoDataset
from megatron.bridge.data.mimo.collate import mimo_collate_fn
from megatron.bridge.data.mimo.dp_utils import get_mimo_dp_info
from megatron.bridge.data.mimo.loaders import build_mimo_data_loaders

# Providers
from megatron.bridge.data.mimo.base_provider import MimoDatasetProvider
from megatron.bridge.data.mimo.collate import mimo_collate_fn
from megatron.bridge.data.mimo.dataset import MimoDataset
from megatron.bridge.data.mimo.dp_utils import get_mimo_dp_info, get_mimo_sampling_info, slice_batch_for_mimo
from megatron.bridge.data.mimo.hf_provider import HFMimoDatasetProvider
from megatron.bridge.data.mimo.loaders import build_mimo_data_loaders
from megatron.bridge.data.mimo.mock_provider import MockMimoProvider


__all__ = [
# Core
"MimoDataset",
Expand All @@ -21,5 +21,7 @@
"MockMimoProvider",
# Utilities
"get_mimo_dp_info",
"get_mimo_sampling_info",
"slice_batch_for_mimo",
"build_mimo_data_loaders",
]
176 changes: 138 additions & 38 deletions src/megatron/bridge/data/mimo/dp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Dict, Tuple
from typing import TYPE_CHECKING, Any, Dict, Tuple

import torch
import torch.distributed as dist
from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY

Expand All @@ -15,61 +16,160 @@
from megatron.bridge.models.mimo.mimo_config import MimoParallelismConfig


def _find_rank_module(
    grids: Dict[str, "HyperCommGrid"],
) -> Tuple["HyperCommGrid | None", "str | None"]:
    """Locate the module grid that contains the current global rank.

    Scans ``grids`` for the first grid whose rank range
    ``[rank_offset, rank_offset + size)`` covers ``dist.get_rank()``.

    Returns:
        ``(grid, module_name)`` on a match, ``(None, None)`` when the
        current rank does not participate in any module.
    """
    rank = dist.get_rank()
    return next(
        (
            (grid, name)
            for name, grid in grids.items()
            if grid.rank_offset <= rank < grid.rank_offset + grid.size
        ),
        (None, None),
    )


def _needs_data_for_module(grid: "HyperCommGrid", module_name: str) -> bool:
    """Return ``True`` when the current rank must load data for *module_name*.

    The language module needs data on both ends of its pipeline
    (``input_ids`` on the first PP stage, ``labels`` on the last), while
    encoders only consume raw modality inputs on their first PP stage.
    """
    group = grid.get_pg(["pp"])
    rank, size = group.rank(), group.size()
    if module_name != MIMO_LANGUAGE_MODULE_KEY:
        return rank == 0
    return rank in (0, size - 1)


def get_mimo_dp_info(
mimo_cfg: "MimoParallelismConfig",
grids: Dict[str, "HyperCommGrid"],
) -> Tuple[int, int, bool, str]:
"""Get DP rank, size, data-loading responsibility, and loader module for MIMO.
"""Get **module-local** DP rank, size, data-loading flag, and module name.

Determines which module's DP settings to use for data loading based on
current rank's participation in heterogeneous deployment.
Returns the DP settings for the module that the current rank participates
in. These are used by :func:`slice_batch_for_mimo` to sub-shard a global
micro-batch into per-module DP shards.

In heterogeneous mode, each rank uses its own module's DP settings.
.. note::
Do **not** use these values to construct a ``DistributedSampler``.
For sampler construction use :func:`get_mimo_sampling_info` instead,
which returns settings that keep all data-loading ranks synchronised
on the same sample order.

Args:
mimo_cfg: MIMO parallelism configuration.
grids: Module name to HyperCommGrid mapping from build_hypercomm_grids().

Returns:
Tuple of (dp_rank, dp_size, needs_data, loader_module):
- dp_rank: This rank's position in DP group.
- dp_size: Size of DP group for data sharding.
- needs_data: Whether this rank needs to load data (first/last PP stage).
- loader_module: Which module's DP settings are being used.

Example:
>>> from megatron.bridge.models.mimo.mimo_builder import build_hypercomm_grids
>>> grids = build_hypercomm_grids(mimo_cfg)
>>> dp_rank, dp_size, needs_data, loader_module = get_mimo_dp_info(mimo_cfg, grids)
>>> if needs_data:
... # Build data loader with dp_rank and dp_size
... sampler = DistributedSampler(dataset, num_replicas=dp_size, rank=dp_rank)
Tuple of (dp_rank, dp_size, needs_data, loader_module).
"""
current_rank = dist.get_rank()

# Heterogeneous: find which module this rank belongs to
my_grid = None
my_module = None
for module_name, grid in grids.items():
if grid.rank_offset <= current_rank < (grid.rank_offset + grid.size):
my_grid = grid
my_module = module_name
break

my_grid, my_module = _find_rank_module(grids)
if my_grid is None or my_module is None:
# Rank doesn't participate in any module
return 0, 1, False, MIMO_LANGUAGE_MODULE_KEY

dp_rank = my_grid.get_pg(["dp"]).rank()
dp_size = my_grid.get_pg(["dp"]).size()
needs_data = _needs_data_for_module(my_grid, my_module)
return dp_rank, dp_size, needs_data, my_module

pp_group = my_grid.get_pg(["pp"])
pp_rank = pp_group.rank()
pp_size = pp_group.size()

if my_module == MIMO_LANGUAGE_MODULE_KEY:
needs_data = (pp_rank == 0) or (pp_rank == pp_size - 1)
else:
needs_data = pp_rank == 0
def get_mimo_sampling_info(
    mimo_cfg: "MimoParallelismConfig",
    grids: Dict[str, "HyperCommGrid"],
) -> Tuple[int, int, bool]:
    """Get sampler DP rank, size, and data-loading flag for MIMO.

    Heterogeneous MIMO modules may run with different DP sizes, yet every
    data-loading rank must see the **same global micro-batch** so that
    :func:`slice_batch_for_mimo` (invoked in the forward step) can sub-shard
    it consistently with the :class:`BridgeCommunicator` fan-in / fan-out
    routing. Sampler-level DP sharding is therefore disabled: all ranks get
    ``dp_rank=0, dp_size=1``, and per-module sharding happens later.

    Args:
        mimo_cfg: MIMO parallelism configuration.
        grids: Module name to HyperCommGrid mapping.

    Returns:
        Tuple of (sampler_dp_rank, sampler_dp_size, needs_data).
    """
    grid, module = _find_rank_module(grids)
    if grid is None or module is None:
        # Rank is outside every module's rank range: never loads data.
        return 0, 1, False

    # dp_rank=0 / dp_size=1 keeps every data-loading rank on the same
    # sample order; module-local DP slicing is deferred to forward_step.
    return 0, 1, _needs_data_for_module(grid, module)


def slice_batch_for_mimo(
    batch: Dict[str, Any],
    dp_rank: int,
    dp_size: int,
) -> Dict[str, Any]:
    """Extract this rank's contiguous module-local DP shard of a global batch.

    Every data-loading rank holds an identical global micro-batch (the
    sampler runs with ``dp_size=1``); this function carves out the
    contiguous slice belonging to ``dp_rank`` so the split matches the
    :class:`BridgeCommunicator`'s batch-dimension split / concatenate logic
    for fan-out and fan-in routing.

    Nested dicts (e.g. ``modality_inputs``) are sliced recursively. Lists
    are sliced only when their length divides evenly by ``dp_size``;
    otherwise they are passed through untouched as global metadata, as are
    all other non-tensor values.

    Args:
        batch: Global batch dict; tensors are shaped [global_batch, ...].
            May contain nested dicts (e.g. modality_inputs → encoder → kwargs).
        dp_rank: Position of this rank in its **module-local** DP group.
        dp_size: Number of replicas in the module-local DP group.

    Returns:
        Dict whose tensors have shape [global_batch // dp_size, ...].

    Raises:
        ValueError: If a tensor's batch dimension is not divisible by
            ``dp_size``.

    Example:
        >>> global_batch = {'tokens': torch.randn(12, 2048)}
        >>> local_batch = slice_batch_for_mimo(global_batch, dp_rank=1, dp_size=3)
        >>> local_batch['tokens'].shape  # torch.Size([4, 2048])
    """
    if dp_size == 1:
        # Single replica: the global batch is already the local batch.
        return batch

    def _span(total: int) -> Tuple[int, int]:
        # Contiguous [start, end) range owned by dp_rank.
        per_rank = total // dp_size
        start = dp_rank * per_rank
        return start, start + per_rank

    out: Dict[str, Any] = {}
    for name, item in batch.items():
        if isinstance(item, torch.Tensor):
            total = item.size(0)
            if total % dp_size != 0:
                raise ValueError(
                    f"Batch size {total} for key '{name}' is not divisible "
                    f"by DP size {dp_size}. Ensure micro_batch_size is divisible "
                    f"by every module's data_parallel_size."
                )
            start, end = _span(total)
            out[name] = item[start:end]
        elif isinstance(item, dict):
            # Recurse into nested dicts (e.g. modality_inputs).
            out[name] = slice_batch_for_mimo(item, dp_rank, dp_size)
        elif isinstance(item, list) and item:
            if len(item) % dp_size == 0:
                start, end = _span(len(item))
                out[name] = item[start:end]
            else:
                # Non-divisible lists are treated as global metadata.
                out[name] = item
        else:
            # Pass through non-tensor, non-list values unchanged.
            out[name] = item

    return out
72 changes: 44 additions & 28 deletions src/megatron/bridge/data/mimo/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Dict, Optional, Tuple
from typing import TYPE_CHECKING, Optional, Tuple

import torch
from torch.utils.data import DataLoader

from megatron.bridge.data.mimo.dp_utils import get_mimo_dp_info
from megatron.bridge.data.mimo.dp_utils import get_mimo_sampling_info
from megatron.bridge.training.config import DatasetBuildContext, DatasetProvider
from megatron.bridge.utils.common_utils import print_rank_0


if TYPE_CHECKING:
from megatron.bridge.training.config import ConfigContainer
from megatron.bridge.training.state import TrainState
Expand All @@ -25,12 +26,14 @@ def build_mimo_data_loaders(
valid_samples: int,
test_samples: int,
) -> Tuple[Optional[DataLoader], Optional[DataLoader], Optional[DataLoader]]:
"""Build MIMO data loaders with per-module DP settings.

Creates data loaders with DP-aware sampling based on the MIMO parallelism
configuration. Only ranks that need data (first/last PP stage) will get
non-None loaders.

"""Build MIMO data loaders with globally consistent sampling.

All data-loading ranks receive identical global micro-batches (the sampler
uses dp_size=1). Per-module DP sub-sharding is deferred to
``slice_batch_for_mimo`` in the forward step, ensuring consistency with
the BridgeCommunicator's fan-in/fan-out routing for asymmetric DP configs.
Only ranks that need data (first/last PP stage) will get non-None loaders.

Args:
cfg: Configuration container with MimoModelProvider as cfg.model.
train_state: Current training state.
Expand All @@ -39,14 +42,14 @@ def build_mimo_data_loaders(
train_samples: Number of training samples.
valid_samples: Number of validation samples.
test_samples: Number of test samples.

Returns:
Tuple of (train_loader, valid_loader, test_loader).
Returns (None, None, None) if this rank doesn't need data.

Raises:
ValueError: If cfg.model is not MimoModelProvider or mimo_parallelism_config is None.

Example:
>>> from megatron.bridge.data.mimo import MockMimoProvider, build_mimo_data_loaders
>>> provider = MockMimoProvider(
Expand All @@ -62,27 +65,36 @@ def build_mimo_data_loaders(
... )
"""
from megatron.bridge.models.mimo.mimo_provider import MimoModelProvider

if not isinstance(cfg.model, MimoModelProvider):
raise ValueError("cfg.model must be MimoModelProvider for MIMO data loading.")

if cfg.model.mimo_parallelism_config is None:
raise ValueError("mimo_parallelism_config must be set for MIMO data loading.")

if cfg.model._grids is None:
raise ValueError(
"MimoModelProvider._grids is None. "
"Ensure build_model() is called before building data loaders."
"MimoModelProvider._grids is None. Ensure build_model() is called before building data loaders."
)

# Validate that micro_batch_size is divisible by every module's DP size.
# slice_batch_for_mimo divides the micro-batch contiguously by the module's
# DP size in forward_step; a non-divisible MBS would leave a remainder.
micro_batch_size = cfg.train.micro_batch_size
for mod_name, mod_cfg in cfg.model.mimo_parallelism_config.module_parallelisms.items():
dp = mod_cfg.data_parallel_size
assert micro_batch_size % dp == 0, (
f"micro_batch_size ({micro_batch_size}) must be divisible by "
f"data_parallel_size ({dp}) of module '{mod_name}'. "
f"slice_batch_for_mimo requires an evenly divisible micro-batch."
)

print_rank_0("> building MIMO train, validation, and test datasets ...")

# Use cached grids from build_model()
grids = cfg.model._grids

dp_rank, dp_size, needs_data, loader_module = get_mimo_dp_info(
cfg.model.mimo_parallelism_config, grids
)

sampler_dp_rank, sampler_dp_size, needs_data = get_mimo_sampling_info(cfg.model.mimo_parallelism_config, grids)

if not needs_data:
return None, None, None
Expand All @@ -96,21 +108,25 @@ def build_mimo_data_loaders(
)
train_ds, valid_ds, test_ds = mimo_provider.build_datasets(context)

print_rank_0(f" Built datasets: train={len(train_ds) if train_ds else 0}, "
f"valid={len(valid_ds) if valid_ds else 0}, "
f"test={len(test_ds) if test_ds else 0}")
print_rank_0(
f" Built datasets: train={len(train_ds) if train_ds else 0}, "
f"valid={len(valid_ds) if valid_ds else 0}, "
f"test={len(test_ds) if test_ds else 0}"
)

# Build data loaders with DP-aware sampling
# Build data loaders with globally consistent sampling.
# sampler_dp_size=1 so all data-loading ranks see the same batches.
# Per-module DP sub-sharding is done later by slice_batch_for_mimo.
collate_fn = mimo_provider.get_collate_fn()
micro_batch_size = cfg.train.micro_batch_size

def _make_loader(dataset, shuffle: bool = True) -> Optional[DataLoader]:
if dataset is None:
return None
sampler = torch.utils.data.DistributedSampler(
dataset,
num_replicas=dp_size,
rank=dp_rank,
num_replicas=sampler_dp_size,
rank=sampler_dp_rank,
shuffle=shuffle,
)
return DataLoader(
Expand All @@ -126,5 +142,5 @@ def _make_loader(dataset, shuffle: bool = True) -> Optional[DataLoader]:
train_loader = _make_loader(train_ds, shuffle=True)
valid_loader = _make_loader(valid_ds, shuffle=False)
test_loader = _make_loader(test_ds, shuffle=False)

return train_loader, valid_loader, test_loader
1 change: 1 addition & 0 deletions src/megatron/bridge/models/mimo/mimo_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def build_hypercomm_grids(
_ = grid.create_pg(["tp", "pp"])
_ = grid.create_pg(["tp", "ep", "pp"])
_ = grid.create_pg(["dp", "ep"])
_ = grid.create_pg(["tp", "cp", "ep", "pp", "dp"])

grids[module_name] = grid

Expand Down
Loading
Loading