Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/Megatron-LM
Submodule Megatron-LM updated 55 files
+6 −1 .github/workflows/cicd-main.yml
+1 −0 .github/workflows/claude-copy-to-main.yml
+7 −3 .github/workflows/claude_review.yml
+4 −1 .gitignore
+12 −0 examples/post_training/modelopt/README.md
+9 −1 examples/post_training/modelopt/export.py
+13 −1 examples/post_training/modelopt/quantize.py
+7 −2 examples/post_training/modelopt/quantize.sh
+6 −11 megatron/core/dist_checkpointing/strategies/async_utils.py
+24 −12 megatron/core/extensions/transformer_engine_spec_provider.py
+9 −0 megatron/core/hyper_comm_grid.py
+21 −19 megatron/core/models/backends.py
+6 −16 megatron/core/models/gpt/moe_module_specs.py
+3 −0 megatron/core/models/mimo/__init__.py
+2 −1 megatron/core/models/mimo/config/__init__.py
+8 −1 megatron/core/models/mimo/config/base_configs.py
+172 −0 megatron/core/models/mimo/config/role.py
+251 −27 megatron/core/models/mimo/model/base.py
+379 −0 megatron/core/models/mimo/optimizer.py
+9 −114 megatron/core/models/mimo/submodules/audio.py
+154 −22 megatron/core/models/mimo/submodules/base.py
+4 −135 megatron/core/models/mimo/submodules/vision.py
+5 −1 megatron/core/models/vision/multimodal_projector.py
+40 −14 megatron/core/pipeline_parallel/bridge_communicator.py
+29 −18 megatron/core/pipeline_parallel/multimodule_communicator.py
+68 −19 megatron/core/transformer/moe/moe_layer.py
+15 −1 megatron/core/transformer/spec_utils.py
+0 −75 megatron/legacy/fused_kernels/__init__.py
+0 −17 megatron/legacy/fused_kernels/compat.h
+0 −0 megatron/legacy/fused_kernels/tests/__init__.py
+0 −389 megatron/legacy/fused_kernels/tests/test_fused_kernels.py
+0 −103 megatron/legacy/fused_kernels/type_shim.h
+0 −4 megatron/legacy/model/module.py
+3 −3 megatron/training/arguments.py
+1 −1 megatron/training/config/common_config.py
+2 −2 megatron/training/config/training_config.py
+11 −45 megatron/training/initialize.py
+4 −3 megatron/training/training.py
+2 −1 megatron/training/utils.py
+16 −20 tests/unit_tests/dist_checkpointing/models/test_moe_experts.py
+48 −0 tests/unit_tests/dist_checkpointing/test_async_utils_shutdown.py
+812 −0 tests/unit_tests/models/test_mimo_1f1b_schedule.py
+30 −0 tests/unit_tests/models/test_mimo_audio_submodules.py
+273 −0 tests/unit_tests/models/test_mimo_checkpoint.py
+302 −308 tests/unit_tests/models/test_mimo_model.py
+47 −0 tests/unit_tests/models/test_mimo_role.py
+39 −0 tests/unit_tests/models/test_mimo_submodules.py
+44 −0 tests/unit_tests/pipeline_parallel/test_bridge_communicator.py
+129 −0 tests/unit_tests/transformer/test_spec_utils.py
+0 −2 tools/checkpoint/loader_base.py
+0 −2 tools/checkpoint/loader_legacy.py
+0 −2 tools/checkpoint/loader_llama_mistral.py
+0 −2 tools/checkpoint/loader_mixtral_hf.py
+0 −2 tools/checkpoint/saver_base.py
+0 −2 tools/checkpoint/saver_legacy.py
232 changes: 102 additions & 130 deletions src/megatron/bridge/data/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from torch.utils.data import DataLoader

from megatron.bridge.data.samplers import build_pretraining_data_loader
from megatron.bridge.training.config import ConfigContainer, DatasetProvider, GPTDatasetConfig
from megatron.bridge.training.config import ConfigContainer, GPTDatasetConfig
from megatron.bridge.training.state import TrainState
from megatron.bridge.training.utils.sig_utils import DistributedSignalHandler
from megatron.bridge.utils.common_utils import print_rank_0
Expand Down Expand Up @@ -178,104 +178,128 @@ def build_train_valid_test_data_loaders(
Returns:
A tuple (train_dataloader, valid_dataloader, test_dataloader).
"""
specialized_builder = _resolve_data_loader_builder(cfg)
if specialized_builder is not None:
train_dataloader, valid_dataloader, test_dataloader = specialized_builder(
# Check for MIMO path
from megatron.bridge.data.mimo.base_provider import MimoDatasetProvider
from megatron.bridge.models.mimo.mimo_provider import MimoModelProvider

if isinstance(cfg.model, MimoModelProvider):
if not isinstance(cfg.dataset, MimoDatasetProvider):
raise ValueError(
"MIMO models require cfg.dataset to be a MimoDatasetProvider. "
"Use HFMimoDatasetProvider, MockMimoProvider, or a subclass of MimoDatasetProvider."
)
from megatron.bridge.data.mimo.loaders import build_mimo_data_loaders

train_samples, valid_samples, test_samples = get_train_valid_test_num_samples(cfg)
train_dataloader, valid_dataloader, test_dataloader = build_mimo_data_loaders(
cfg=cfg,
train_state=train_state,
build_train_valid_test_datasets_provider=build_train_valid_test_datasets_provider,
dp_group=dp_group,
mimo_provider=cfg.dataset,
train_samples=train_samples,
valid_samples=valid_samples,
test_samples=test_samples,
)
else:
(train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

print_rank_0("> building train, validation, and test datasets ...")

# Construct the data pipeline
# Build datasets.
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
cfg=cfg, build_train_valid_test_datasets_provider=build_train_valid_test_datasets_provider
)
# Sync train_state flags across all ranks.
# Use all_reduce(MAX) since some ranks may not have loaders in heterogeneous MIMO.
do_train = train_dataloader is not None and cfg.train.train_iters > 0
do_valid = valid_dataloader is not None and cfg.train.eval_iters > 0
do_test = test_dataloader is not None and cfg.train.eval_iters > 0
Comment on lines +205 to +207
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Use the validation config here.

The rest of this module reads evaluation settings from cfg.validation; using cfg.train.eval_iters on the MIMO path can raise or permanently disable do_valid/do_test.

Suggested fix
-        do_valid = valid_dataloader is not None and cfg.train.eval_iters > 0
-        do_test = test_dataloader is not None and cfg.train.eval_iters > 0
+        do_valid = valid_dataloader is not None and cfg.validation.eval_iters > 0
+        do_test = test_dataloader is not None and cfg.validation.eval_iters > 0
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/megatron/bridge/data/loaders.py` around lines 205 - 207, The do_valid and
do_test flags are using cfg.train.eval_iters but the module reads evaluation
settings from cfg.validation; update the checks in the block that sets
do_train/do_valid/do_test so that do_valid and do_test use
cfg.validation.eval_iters (or the appropriate cfg.validation field) instead of
cfg.train.eval_iters, keeping do_train unchanged; adjust references to the
variables do_valid and do_test in the surrounding code if needed to ensure
evaluation behavior follows cfg.validation.eval_iters.

flags = torch.tensor([int(do_train), int(do_valid), int(do_test)], dtype=torch.long, device="cuda")
torch.distributed.all_reduce(flags, op=torch.distributed.ReduceOp.MAX)
train_state.do_train = flags[0].item()
train_state.do_valid = flags[1].item()
train_state.do_test = flags[2].item()

exit_signal = cfg.train.exit_signal
return train_dataloader, valid_dataloader, test_dataloader

def worker_init_fn(_):
DistributedSignalHandler(exit_signal).__enter__()
(train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

maybe_worker_init_fn = worker_init_fn if cfg.train.exit_signal_handler_for_dataloader else None
print_rank_0("> building train, validation, and test datasets ...")

# Resolve DP rank/size from provided data-parallel process group
dp_rank = torch.distributed.get_rank(group=dp_group)
dp_size = torch.distributed.get_world_size(group=dp_group)
# Construct the data pipeline
# Build datasets.
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
cfg=cfg, build_train_valid_test_datasets_provider=build_train_valid_test_datasets_provider
)

# Build dataloaders.
train_dataloader = build_pretraining_data_loader(
train_ds,
train_state.consumed_train_samples,
exit_signal = cfg.train.exit_signal

def worker_init_fn(_):
DistributedSignalHandler(exit_signal).__enter__()

maybe_worker_init_fn = worker_init_fn if cfg.train.exit_signal_handler_for_dataloader else None

# Resolve DP rank/size from provided data-parallel process group
dp_rank = torch.distributed.get_rank(group=dp_group)
dp_size = torch.distributed.get_world_size(group=dp_group)

# Build dataloaders.
train_dataloader = build_pretraining_data_loader(
train_ds,
train_state.consumed_train_samples,
cfg.dataset.dataloader_type,
cfg.train.micro_batch_size,
cfg.dataset.num_workers,
cfg.dataset.data_sharding,
worker_init_fn=maybe_worker_init_fn,
collate_fn=train_ds.collate_fn if hasattr(train_ds, "collate_fn") else None,
pin_memory=cfg.dataset.pin_memory,
persistent_workers=cfg.dataset.persistent_workers,
data_parallel_rank=dp_rank,
data_parallel_size=dp_size,
global_batch_size=cfg.train.global_batch_size,
)
if cfg.validation.skip_train and cfg.validation.eval_iters > 0:
valid_dataloader = build_pretraining_data_loader(
valid_ds,
0,
cfg.dataset.dataloader_type,
cfg.train.micro_batch_size,
cfg.dataset.num_workers,
cfg.dataset.data_sharding,
worker_init_fn=maybe_worker_init_fn,
collate_fn=train_ds.collate_fn if hasattr(train_ds, "collate_fn") else None,
collate_fn=valid_ds.collate_fn if hasattr(valid_ds, "collate_fn") else None,
pin_memory=cfg.dataset.pin_memory,
persistent_workers=cfg.dataset.persistent_workers,
data_parallel_rank=dp_rank,
data_parallel_size=dp_size,
global_batch_size=cfg.train.global_batch_size,
)
elif cfg.validation.eval_iters > 0:
val_dataloader_type = "cyclic" if isinstance(cfg.dataset, GPTDatasetConfig) else cfg.dataset.dataloader_type
valid_dataloader = build_pretraining_data_loader(
valid_ds,
train_state.consumed_valid_samples,
val_dataloader_type,
cfg.train.micro_batch_size,
cfg.dataset.num_workers,
cfg.dataset.data_sharding,
worker_init_fn=maybe_worker_init_fn,
collate_fn=valid_ds.collate_fn if hasattr(valid_ds, "collate_fn") else None,
pin_memory=cfg.dataset.pin_memory,
persistent_workers=cfg.dataset.persistent_workers,
data_parallel_rank=dp_rank,
data_parallel_size=dp_size,
global_batch_size=cfg.train.global_batch_size,
)
if cfg.validation.skip_train and cfg.validation.eval_iters > 0:
valid_dataloader = build_pretraining_data_loader(
valid_ds,
0,
cfg.dataset.dataloader_type,
cfg.train.micro_batch_size,
cfg.dataset.num_workers,
cfg.dataset.data_sharding,
worker_init_fn=maybe_worker_init_fn,
collate_fn=valid_ds.collate_fn if hasattr(valid_ds, "collate_fn") else None,
pin_memory=cfg.dataset.pin_memory,
persistent_workers=cfg.dataset.persistent_workers,
data_parallel_rank=dp_rank,
data_parallel_size=dp_size,
global_batch_size=cfg.train.global_batch_size,
)
elif cfg.validation.eval_iters > 0:
val_dataloader_type = (
"cyclic" if isinstance(cfg.dataset, GPTDatasetConfig) else cfg.dataset.dataloader_type
)
valid_dataloader = build_pretraining_data_loader(
valid_ds,
train_state.consumed_valid_samples,
val_dataloader_type,
cfg.train.micro_batch_size,
cfg.dataset.num_workers,
cfg.dataset.data_sharding,
worker_init_fn=maybe_worker_init_fn,
collate_fn=valid_ds.collate_fn if hasattr(valid_ds, "collate_fn") else None,
pin_memory=cfg.dataset.pin_memory,
persistent_workers=cfg.dataset.persistent_workers,
data_parallel_rank=dp_rank,
data_parallel_size=dp_size,
global_batch_size=cfg.train.global_batch_size,
)

if cfg.validation.eval_iters > 0:
test_dataloader = build_pretraining_data_loader(
test_ds,
0,
cfg.dataset.dataloader_type,
cfg.train.micro_batch_size,
cfg.dataset.num_workers,
cfg.dataset.data_sharding,
worker_init_fn=maybe_worker_init_fn,
collate_fn=test_ds.collate_fn if hasattr(test_ds, "collate_fn") else None,
pin_memory=cfg.dataset.pin_memory,
persistent_workers=cfg.dataset.persistent_workers,
data_parallel_rank=dp_rank,
data_parallel_size=dp_size,
global_batch_size=cfg.train.global_batch_size,
)
if cfg.validation.eval_iters > 0:
test_dataloader = build_pretraining_data_loader(
test_ds,
0,
cfg.dataset.dataloader_type,
cfg.train.micro_batch_size,
cfg.dataset.num_workers,
cfg.dataset.data_sharding,
worker_init_fn=maybe_worker_init_fn,
collate_fn=test_ds.collate_fn if hasattr(test_ds, "collate_fn") else None,
pin_memory=cfg.dataset.pin_memory,
persistent_workers=cfg.dataset.persistent_workers,
data_parallel_rank=dp_rank,
data_parallel_size=dp_size,
global_batch_size=cfg.train.global_batch_size,
)

# Flags to know if we need to do training/validation/testing.
do_train = train_dataloader is not None and cfg.train.train_iters > 0
Expand All @@ -292,58 +316,6 @@ def worker_init_fn(_):
return train_dataloader, valid_dataloader, test_dataloader


def _resolve_data_loader_builder(
    cfg: ConfigContainer,
) -> Optional[Callable[..., tuple[Optional[DataLoader], Optional[DataLoader], Optional[DataLoader]]]]:
    """Look up a model-specific data loader builder for ``cfg.model``.

    A small type-to-builder registry is consulted so that the generic data
    loader path does not have to embed model-specific construction logic.

    Args:
        cfg: Configuration container; only ``cfg.model`` is inspected here.

    Returns:
        The builder registered for the first type that ``cfg.model`` is an
        instance of, or ``None`` when no registered type matches.
    """
    # Function-scope import: the provider class is only needed for this lookup.
    from megatron.bridge.models.mimo.mimo_provider import MimoModelProvider

    registry = {
        MimoModelProvider: _build_mimo_train_valid_test_data_loaders,
    }

    # First match in registration order wins; fall back to None otherwise.
    return next(
        (build_fn for provider_cls, build_fn in registry.items() if isinstance(cfg.model, provider_cls)),
        None,
    )


def _build_mimo_train_valid_test_data_loaders(
    cfg: ConfigContainer,
    train_state: TrainState,
    build_train_valid_test_datasets_provider: Callable,  # Accepted only to match the shared builder signature.
    dp_group: torch.distributed.ProcessGroup,  # Accepted only to match the shared builder signature.
) -> tuple[Optional[DataLoader], Optional[DataLoader], Optional[DataLoader]]:
    """Construct the train/valid/test loaders through the MIMO-specific path.

    Args:
        cfg: Configuration container; ``cfg.dataset`` must be a
            ``DatasetProvider`` exposing a callable ``get_collate_fn``.
        train_state: Training state forwarded to ``build_mimo_data_loaders``.
        build_train_valid_test_datasets_provider: Ignored by this builder.
        dp_group: Ignored by this builder.

    Returns:
        Whatever ``build_mimo_data_loaders`` returns, annotated as a
        ``(train, valid, test)`` tuple of optional data loaders.

    Raises:
        ValueError: If ``cfg.dataset`` is not a ``DatasetProvider`` or lacks a
            callable ``get_collate_fn`` attribute.
    """
    # Both arguments are part of the common signature but unused here.
    del build_train_valid_test_datasets_provider
    del dp_group

    dataset_cfg = cfg.dataset
    # The collate hook is only probed once the isinstance check has passed.
    if not (isinstance(dataset_cfg, DatasetProvider) and callable(getattr(dataset_cfg, "get_collate_fn", None))):
        raise ValueError(
            "MIMO models require cfg.dataset to implement DatasetProvider.build_datasets() "
            "and a MIMO-compatible get_collate_fn() method (e.g., HFMimoDatasetProvider "
            "or MockMimoProvider)."
        )

    # Imported only after validation succeeds, as in the original ordering.
    from megatron.bridge.data.mimo.loaders import build_mimo_data_loaders

    num_train, num_valid, num_test = get_train_valid_test_num_samples(cfg)
    loader_kwargs = {
        "cfg": cfg,
        "train_state": train_state,
        "mimo_provider": dataset_cfg,
        "train_samples": num_train,
        "valid_samples": num_valid,
        "test_samples": num_test,
    }
    return build_mimo_data_loaders(**loader_kwargs)


def build_train_valid_test_data_iterators(
cfg: ConfigContainer,
train_state: TrainState,
Expand Down
14 changes: 8 additions & 6 deletions src/megatron/bridge/data/mimo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
"""MIMO multi-encoder data loading utilities."""

# Providers
from megatron.bridge.data.mimo.collate import mimo_collate_fn
from megatron.bridge.data.mimo.dataset import MimoDataset
from megatron.bridge.data.mimo.collate import mimo_collate_fn
from megatron.bridge.data.mimo.dp_utils import get_mimo_dp_info
from megatron.bridge.data.mimo.hf_provider import HFMimoDatasetProvider
from megatron.bridge.data.mimo.loaders import build_mimo_data_loaders
from megatron.bridge.data.mimo.mock_provider import MockMimoProvider

# Providers
from megatron.bridge.data.mimo.base_provider import MimoDatasetProvider
from megatron.bridge.data.mimo.hf_provider import HFMimoDatasetProvider
from megatron.bridge.data.mimo.mock_provider import MockMimoProvider

__all__ = [
# Core
"MimoDataset",
"mimo_collate_fn",
# Providers
# Providers (base + implementations)
"MimoDatasetProvider",
"HFMimoDatasetProvider",
"MockMimoProvider",
# Utilities
Expand Down
63 changes: 63 additions & 0 deletions src/megatron/bridge/data/mimo/base_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
"""Base class for MIMO dataset providers."""

from __future__ import annotations

from abc import abstractmethod
from dataclasses import dataclass
from typing import Callable, Optional, Tuple

from torch.utils.data import Dataset

from megatron.bridge.training.config import DatasetBuildContext, DatasetProvider


@dataclass(kw_only=True)
class MimoDatasetProvider(DatasetProvider):
    """Common interface that MIMO dataset providers are expected to implement.

    Concrete providers subclass this type and supply two hooks, giving MIMO
    data loading a single, consistent contract to program against:

    * ``build_datasets`` -- create the train/valid/test datasets.
    * ``get_collate_fn`` -- hand back the callable used to batch samples.

    Example:
        >>> class MyMimoProvider(MimoDatasetProvider):
        ...     def build_datasets(self, context):
        ...         # Build and return datasets
        ...         return train_ds, valid_ds, test_ds
        ...
        ...     def get_collate_fn(self):
        ...         # Return collate function
        ...         return my_collate_fn
    """

    @abstractmethod
    def build_datasets(
        self, context: DatasetBuildContext
    ) -> Tuple[Optional[Dataset], Optional[Dataset], Optional[Dataset]]:
        """Create the train, validation, and test datasets.

        Args:
            context: Build context with sample counts.

        Returns:
            A ``(train_dataset, valid_dataset, test_dataset)`` tuple; any
            entry may be ``None`` when that split is not needed.
        """

    @abstractmethod
    def get_collate_fn(self) -> Callable:
        """Provide the collate function used for batching.

        The returned callable should handle the ``modality_inputs`` dict and
        batch it appropriately for the model.

        Returns:
            A callable that takes a list of samples and returns a batch dict.
        """
2 changes: 2 additions & 0 deletions src/megatron/bridge/data/mimo/collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def mimo_collate_fn(
labels = torch.stack([item["labels"] for item in batch])
attention_mask = torch.stack([item["attention_mask"] for item in batch])
position_ids = torch.stack([item["position_ids"] for item in batch])
loss_mask = torch.stack([item["loss_mask"] for item in batch])

Comment on lines +62 to 63
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Update mimo_collate_fn docstring to include loss_mask.

Line 62 and Line 114 make loss_mask part of the required collated payload, but the Args/Returns/example contract still omits it.

📝 Suggested doc update
     Args:
         batch: List of examples from MimoDataset, each containing:
             - input_ids: Token IDs with placeholder tokens
             - labels: Labels for causal LM training
+            - loss_mask: Per-token loss mask
             - attention_mask: Attention mask
             - position_ids: Position indices
             - modality_inputs: Dict[str, Dict[str, Any]] with preprocessed inputs
@@
     Returns:
         Dict containing:
             - input_ids: (batch, seq) stacked token IDs
             - labels: (batch, seq) stacked labels
+            - loss_mask: (batch, seq) stacked per-token loss mask
             - attention_mask: (batch, seq) attention mask
             - position_ids: (batch, seq) position indices
             - modality_inputs: Dict[str, Dict[str, Tensor]] with batched modality tensors

As per the coding guidelines: "For interfaces that may be used outside a file, prefer docstrings over comments."

Also applies to: 114-114

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/megatron/bridge/data/mimo/collate.py` around lines 62 - 63, The docstring
for mimo_collate_fn is missing documentation for the newly collated "loss_mask"
field; update the mimo_collate_fn docstring (Args, Returns and example sections)
to list loss_mask as a required tensor in the collated payload, describe its
shape and purpose (e.g., mask for loss computation), and show it in the example
return structure alongside existing keys so external callers see the API
contract.

# Collate modality inputs
modality_inputs: Dict[str, Dict[str, Any]] = {}
Expand Down Expand Up @@ -110,6 +111,7 @@ def mimo_collate_fn(
return {
"input_ids": input_ids,
"labels": labels,
"loss_mask": loss_mask,
"attention_mask": attention_mask,
"position_ids": position_ids,
"modality_inputs": modality_inputs,
Expand Down
Loading
Loading