Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 18 additions & 11 deletions src/megatron/bridge/data/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,12 @@ def build_train_valid_test_data_loaders(
exit_signal = cfg.train.exit_signal

def worker_init_fn(_):
"""
Install the distributed exit signal handler in a dataloader worker process.

Parameters:
_ (int): Worker process index (unused).
"""
DistributedSignalHandler(exit_signal).__enter__()

maybe_worker_init_fn = worker_init_fn if cfg.train.exit_signal_handler_for_dataloader else None
Expand Down Expand Up @@ -322,18 +328,19 @@ def build_train_valid_test_data_iterators(
build_train_valid_test_datasets_provider: Callable,
dp_group: torch.distributed.ProcessGroup,
) -> tuple[Optional[RerunDataIterator], Optional[RerunDataIterator], Optional[RerunDataIterator]]:
"""Build train, validation, and test data iterators.

Builds the data loaders first, then wraps them in appropriate iterators
(e.g., RerunDataIterator, cyclic_iter) based on the configuration.

Args:
cfg: The main configuration container.
train_state: The current training state.
build_train_valid_test_datasets_provider: A function to build the datasets.

"""
Build iterators for training, validation, and testing data.

Wraps dataloaders produced by build_train_valid_test_data_loaders into iterator objects according to the dataset's dataloader_type; uses a "cyclic" iterator for validation when the dataset is a GPTDatasetConfig. For "external" dataloaders, preserves list or single dataloader semantics.

Parameters:
cfg (ConfigContainer): Configuration that determines dataloader types and dataset settings.
train_state (TrainState): Current training state used when constructing dataloaders.
build_train_valid_test_datasets_provider (Callable): Provider used to build datasets for the dataloaders.
dp_group (torch.distributed.ProcessGroup): Data-parallel process group used when building dataloaders.

Returns:
A tuple (train_data_iterator, valid_data_iterator, test_data_iterator).
tuple: (train_data_iterator, valid_data_iterator, test_data_iterator) where each element is a RerunDataIterator or a list of RerunDataIterator for "external" dataloaders, or `None` if the corresponding dataloader was not created.
"""

# Build loaders.
Expand Down
20 changes: 9 additions & 11 deletions src/megatron/bridge/data/mimo/base_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,25 +39,23 @@ class MimoDatasetProvider(DatasetProvider):
def build_datasets(
self, context: DatasetBuildContext
) -> Tuple[Optional[Dataset], Optional[Dataset], Optional[Dataset]]:
"""Build train, validation, and test datasets.
"""
Construct the provider's train, validation, and test datasets.

Parameters:
context (DatasetBuildContext): Build context containing sample counts used during dataset construction.

Args:
context: Build context with sample counts.

Returns:
Tuple of (train_dataset, valid_dataset, test_dataset).
Any element can be None if not needed.
Tuple[Optional[Dataset], Optional[Dataset], Optional[Dataset]]: A 3-tuple (train_dataset, valid_dataset, test_dataset); any element may be `None` if that split is not produced.
"""
...

@abstractmethod
def get_collate_fn(self) -> Callable:
"""Return the collate function for batching.

The collate function should handle the modality_inputs dict
and batch them appropriately for the model.
"""
Provide the callable used to collate a list of samples into a batched `modality_inputs` dictionary.

Returns:
Callable that takes a list of samples and returns a batch dict.
Callable[[List[Any]], Dict[str, Any]]: A callable that accepts a list of samples and returns a dictionary mapping modality keys to their batched tensors/structures.
"""
...
57 changes: 21 additions & 36 deletions src/megatron/bridge/data/mimo/collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,43 +13,28 @@ def mimo_collate_fn(
batch: List[Dict[str, Any]],
modality_names: List[str],
) -> Dict[str, Any]:
"""Collate function for MIMO datasets.

Stacks batch items and organizes modality inputs into a structure
suitable for MIMO model forward pass.

Args:
batch: List of examples from MimoDataset, each containing:
- input_ids: Token IDs with placeholder tokens
- labels: Labels for causal LM training
- attention_mask: Attention mask
- position_ids: Position indices
- modality_inputs: Dict[str, Dict[str, Any]] with preprocessed inputs
modality_names: List of modality names to collate.

"""
Collate a list of MIMO dataset examples into a batched dict suitable for model input.

Parameters:
batch: List of example dictionaries. Each example is expected to contain:
- input_ids: Tensor of token IDs.
- labels: Tensor of target token IDs.
- attention_mask: Tensor attention mask aligning with tokens.
- position_ids: Tensor of position indices for tokens.
- loss_mask: Tensor indicating per-token loss contribution.
- modality_inputs: Dict[str, Dict[str, Any]] mapping modality names to modality-specific preprocessed inputs.
modality_names: List of modality names to gather from each example's `modality_inputs`.

Returns:
Dict containing:
- input_ids: (batch, seq) stacked token IDs
- labels: (batch, seq) stacked labels
- attention_mask: (batch, seq) attention mask
- position_ids: (batch, seq) position indices
- modality_inputs: Dict[str, Dict[str, Tensor]] with batched modality tensors
Each modality's tensors are stacked along batch dimension.

Example:
>>> batch = [
... {
... "input_ids": torch.tensor([32000, 1, 2, 3]),
... "labels": torch.tensor([32000, 1, 2, 3]),
... "attention_mask": torch.ones(4),
... "position_ids": torch.arange(4),
... "modality_inputs": {
... "vision": {"pixel_values": torch.randn(3, 224, 224)},
... },
... },
... # ... more examples
... ]
>>> collated = mimo_collate_fn(batch, modality_names=["vision"])
A dictionary with the following entries:
- input_ids: Tensor shaped (batch, seq) of stacked input IDs.
- labels: Tensor shaped (batch, seq) of stacked labels.
- loss_mask: Tensor shaped (batch, seq) of stacked loss masks.
- attention_mask: Tensor shaped (batch, seq) of stacked attention masks.
- position_ids: Tensor shaped (batch, seq) of stacked position indices.
- modality_inputs: Dict mapping each present modality name to a dict of batched modality values.
For each modality key, tensor values are stacked along the batch dimension when shapes permit; values that cannot be stacked or non-tensor values are returned as lists of per-example entries.
"""
if not batch:
return {}
Expand Down
25 changes: 15 additions & 10 deletions src/megatron/bridge/data/mimo/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,17 +112,22 @@ def __len__(self) -> int:
return self._size

def __getitem__(self, idx: int) -> Dict[str, Any]:
"""Get a single example with preprocessed modality inputs.

"""
Retrieve and preprocess a single dataset example, producing tokenized input IDs with modality placeholder tokens, next-token prediction labels, masking tensors, and per-modality processed inputs.

The returned dictionary contains:
- input_ids: 1D tensor of token IDs of length `seq_length`, with modality placeholder token IDs prepended and padded/truncated to fit.
- labels: 1D tensor of next-token targets where labels[i] == input_ids[i+1], the final position is set to `-100`, and any position corresponding to padding or a modality placeholder is set to `-100` so it is ignored by loss.
- loss_mask: 1D float tensor with `1.0` for positions contributing to the loss and `0.0` for positions that should be ignored (padding, modality placeholders, and the final position).
- attention_mask: 1D tensor indicating non-padded token positions.
- position_ids: 1D tensor of position indices (0..seq_length-1).
- modality_inputs: dict mapping modality name to the processor outputs for that modality (tensors with batch dim removed where applicable).

Parameters:
idx (int): Index of the example to retrieve.

Returns:
Dict containing:
- input_ids: Tokenized text with placeholder tokens
- labels: Shifted input_ids for next-token prediction (-100 for masked positions)
- loss_mask: Float mask (0.0 for padding/image placeholder targets, 1.0 otherwise)
- attention_mask: Attention mask
- position_ids: Position indices
- modality_inputs: Dict[str, Any] with preprocessed inputs per modality
e.g., {"vision": {"pixel_values": tensor, ...}}
dict: A mapping with keys `"input_ids"`, `"labels"`, `"loss_mask"`, `"attention_mask"`, `"position_ids"`, and `"modality_inputs"` as described above.
"""
if idx >= self._size:
raise IndexError(f"Index {idx} out of range for dataset of size {self._size}")
Expand Down
38 changes: 14 additions & 24 deletions src/megatron/bridge/data/mimo/dp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,31 +19,21 @@ def get_mimo_dp_info(
mimo_cfg: "MimoParallelismConfig",
grids: Dict[str, "HyperCommGrid"],
) -> Tuple[int, int, bool, str]:
"""Get DP rank, size, data-loading responsibility, and loader module for MIMO.

Determines which module's DP settings to use for data loading based on
current rank's participation in heterogeneous deployment.

In heterogeneous mode, each rank uses its own module's DP settings.

Args:
mimo_cfg: MIMO parallelism configuration.
grids: Module name to HyperCommGrid mapping from build_hypercomm_grids().

"""
Return the DP rank and size, whether the current rank must load data, and which module's DP settings to use for the current distributed rank.

Determines which module (from `grids`) contains the current global rank and computes the DP process-group rank and size. If the rank does not belong to any provided grid, returns defaults for a non-participating rank: (0, 1, False, MIMO_LANGUAGE_MODULE_KEY). `needs_data` is true for the language module when the rank is on the first or last PP stage; for other modules it is true only on the first PP stage.

Parameters:
mimo_cfg (MimoParallelismConfig): MIMO parallelism configuration (kept for API compatibility).
grids (Dict[str, HyperCommGrid]): Mapping from module name to its HyperCommGrid produced by build_hypercomm_grids().

Returns:
Tuple of (dp_rank, dp_size, needs_data, loader_module):
- dp_rank: This rank's position in DP group.
- dp_size: Size of DP group for data sharding.
- needs_data: Whether this rank needs to load data (first/last PP stage).
- loader_module: Which module's DP settings are being used.

Example:
>>> from megatron.bridge.models.mimo.mimo_builder import build_hypercomm_grids
>>> grids = build_hypercomm_grids(mimo_cfg)
>>> dp_rank, dp_size, needs_data, loader_module = get_mimo_dp_info(mimo_cfg, grids)
>>> if needs_data:
... # Build data loader with dp_rank and dp_size
... sampler = DistributedSampler(dataset, num_replicas=dp_size, rank=dp_rank)
Tuple[int, int, bool, str]: `(dp_rank, dp_size, needs_data, loader_module)` where
- `dp_rank`: this rank's index within its DP process group,
- `dp_size`: size of the DP process group used for data sharding,
- `needs_data`: `True` if the rank should load data, `False` otherwise,
- `loader_module`: name of the module whose DP settings should be used (or `MIMO_LANGUAGE_MODULE_KEY` for non-participating ranks).
"""
current_rank = dist.get_rank()

Expand Down
35 changes: 26 additions & 9 deletions src/megatron/bridge/data/mimo/hf_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,15 @@ def _load_tokenizer(self) -> Any:
return tokenizer

def _load_hf_dataset(self, split: str) -> Any:
"""Load a HuggingFace dataset split."""
"""
Load a specified split from the configured HuggingFace dataset.

Parameters:
split (str): Name of the split to load (e.g., "train", "validation", "test").

Returns:
dataset (Any): The loaded dataset split, or `None` if the requested split does not exist.
"""
try:
dataset = load_dataset(
self.hf_dataset_path,
Expand All @@ -146,7 +154,16 @@ def _build_split_dataset(
processors: Dict[str, Any],
tokenizer: Any,
) -> Optional[MimoDataset]:
"""Build dataset for a single split."""
"""
Create a MimoDataset for the specified HuggingFace dataset split.

Parameters:
split: The name of the dataset split to load (e.g., "train", "validation", "test").
target_samples: Maximum number of examples to include; if less than or equal to 0 the function returns `None`.

Returns:
MimoDataset: A dataset configured with the loaded split, processors, tokenizer, and provider settings, or `None` if the split is missing or `target_samples` is <= 0.
"""
if target_samples <= 0:
return None

Expand All @@ -170,14 +187,14 @@ def _build_split_dataset(
def build_datasets(
self, context: DatasetBuildContext
) -> Tuple[Optional[Dataset], Optional[Dataset], Optional[Dataset]]:
"""Build train, validation, and test datasets.

Args:
context: Build context with sample counts.

"""
Build datasets for training, validation, and testing from the configured HuggingFace dataset.

Parameters:
context (DatasetBuildContext): Provides target sample counts for train, valid, and test builds.

Returns:
Tuple of (train_dataset, valid_dataset, test_dataset).
Any element can be None if split doesn't exist or sample count is 0.
tuple: (train_dataset, valid_dataset, test_dataset) where each element is a Dataset or `None` if the corresponding split is missing or the requested sample count is less than or equal to zero.
"""
processors = self._load_processors()
tokenizer = self._load_tokenizer()
Expand Down
56 changes: 23 additions & 33 deletions src/megatron/bridge/data/mimo/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,41 +26,21 @@ def build_mimo_data_loaders(
valid_samples: int,
test_samples: int,
) -> Tuple[Optional[DataLoader], Optional[DataLoader], Optional[DataLoader]]:
"""Build MIMO data loaders with per-module DP settings.

Creates data loaders with DP-aware sampling based on the MIMO parallelism
configuration. Only ranks that need data (first/last PP stage) will get
non-None loaders.

Args:
cfg: Configuration container with MimoModelProvider as cfg.model.
train_state: Current training state.
mimo_provider: MIMO dataset provider (e.g., MockMimoDatasetProvider)
with get_collate_fn() method.
train_samples: Number of training samples.
valid_samples: Number of validation samples.
test_samples: Number of test samples.

"""
Build MIMO data loaders using per-module data-parallel (DP) sampling derived from the MIMO parallelism configuration.

Parameters:
cfg: Configuration container whose `model` must be a `MimoModelProvider` with a populated `_grids` attribute and a non-`None` `mimo_parallelism_config`.
mimo_provider: MIMO dataset provider that exposes `build_datasets(context)` and `get_collate_fn()`, and provides loader settings (`num_workers`, `pin_memory`, `drop_last`).
train_samples (int): Number of training samples to request when building datasets.
valid_samples (int): Number of validation samples to request when building datasets.
test_samples (int): Number of test samples to request when building datasets.

Returns:
Tuple of (train_loader, valid_loader, test_loader).
Returns (None, None, None) if this rank doesn't need data.

Tuple of `(train_loader, valid_loader, test_loader)`. Each element is a `torch.utils.data.DataLoader` configured with DP-aware `DistributedSampler`, or `None` when no dataset was built or when the current rank does not require data (in which case all three are `None`).

Raises:
ValueError: If cfg.model is not MimoModelProvider or mimo_parallelism_config is None.

Example:
>>> from megatron.bridge.data.mimo import MockMimoProvider, build_mimo_data_loaders
>>> provider = MockMimoProvider(
... seq_length=2048,
... processor_paths={"vision": "openai/clip-vit-large-patch14"},
... tokenizer_path="meta-llama/Llama-2-7b-hf",
... special_token_ids={"vision": 32000},
... modality_configs={"vision": {"type": "image", "width": 224, "height": 224}},
... )
>>> train_loader, valid_loader, test_loader = build_mimo_data_loaders(
... cfg, train_state, provider,
... train_samples=10000, valid_samples=1000, test_samples=1000,
... )
ValueError: If `cfg.model` is not a `MimoModelProvider`, if `cfg.model.mimo_parallelism_config` is `None`, or if `cfg.model._grids` is `None` (indicating model infra has not been built).
"""
from megatron.bridge.models.mimo.mimo_provider import MimoModelProvider

Expand Down Expand Up @@ -105,6 +85,16 @@ def build_mimo_data_loaders(
micro_batch_size = cfg.train.micro_batch_size

def _make_loader(dataset, shuffle: bool = True) -> Optional[DataLoader]:
"""
Create a DataLoader for the given dataset configured with a DistributedSampler for the current MIMO data-parallel replica.

Parameters:
dataset: The dataset to load; if `None`, no loader is created and `None` is returned.
shuffle (bool): Whether the distributed sampler should shuffle sample order for this split.

Returns:
DataLoader or `None`: A DataLoader using the configured `DistributedSampler` and provider settings, or `None` when `dataset` is `None`.
"""
if dataset is None:
return None
sampler = torch.utils.data.DistributedSampler(
Expand Down
24 changes: 22 additions & 2 deletions src/megatron/bridge/data/mimo/mock_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,17 @@


def _generate_random_image(width: int, height: int, rng: np.random.Generator) -> Image.Image:
"""Generate a random RGB image."""
"""
Create a PIL RGB image of the given pixel dimensions with random pixel values.

Parameters:
width (int): Image width in pixels.
height (int): Image height in pixels.
rng (np.random.Generator): NumPy random number generator used to sample uint8 RGB pixel values deterministically when seeded.

Returns:
Image.Image: A PIL Image in "RGB" mode containing random pixels.
"""
pixels = rng.integers(low=0, high=256, size=(height, width, 3), dtype=np.uint8)
return Image.fromarray(pixels, mode="RGB")

Expand Down Expand Up @@ -110,7 +120,17 @@ def _load_processors(self) -> Dict[str, Any]:
return processors

def _load_tokenizer(self) -> Any:
"""Load HuggingFace tokenizer."""
"""
Load and cache the HuggingFace tokenizer specified by `tokenizer_path`.

If a tokenizer is already cached on this instance, the cached object is returned. Ensures the tokenizer has a `pad_token` by assigning `eos_token` to `pad_token` when missing and stores the loaded tokenizer in `self._tokenizer`.

Raises:
ValueError: If `self.tokenizer_path` is empty or falsy.

Returns:
tokenizer: The loaded HuggingFace tokenizer instance.
"""
if self._tokenizer is not None:
return self._tokenizer

Expand Down
Loading
Loading