Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/Megatron-LM
Submodule Megatron-LM updated 1028 files
17 changes: 10 additions & 7 deletions src/megatron/bridge/data/mimo/dp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@
from typing import TYPE_CHECKING, Dict, Tuple

import torch.distributed as dist
from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY


if TYPE_CHECKING:
from megatron.core.hyper_comm_grid import HyperCommGrid

from megatron.bridge.models.mimo.mimo_config import MimoParallelismConfig


Expand All @@ -17,23 +20,23 @@ def get_mimo_dp_info(
grids: Dict[str, "HyperCommGrid"],
) -> Tuple[int, int, bool, str]:
"""Get DP rank, size, data-loading responsibility, and loader module for MIMO.

Determines which module's DP settings to use for data loading based on
current rank's participation in heterogeneous deployment.

In heterogeneous mode, each rank uses its own module's DP settings.

Args:
mimo_cfg: MIMO parallelism configuration.
grids: Module name to HyperCommGrid mapping from build_hypercomm_grids().

Returns:
Tuple of (dp_rank, dp_size, needs_data, loader_module):
- dp_rank: This rank's position in DP group.
- dp_size: Size of DP group for data sharding.
- needs_data: Whether this rank needs to load data (first/last PP stage).
- loader_module: Which module's DP settings are being used.

Example:
>>> from megatron.bridge.models.mimo.mimo_builder import build_hypercomm_grids
>>> grids = build_hypercomm_grids(mimo_cfg)
Expand All @@ -55,7 +58,7 @@ def get_mimo_dp_info(

if my_grid is None or my_module is None:
# Rank doesn't participate in any module
return 0, 1, False, "llm"
return 0, 1, False, MIMO_LANGUAGE_MODULE_KEY

dp_rank = my_grid.get_pg(["dp"]).rank()
dp_size = my_grid.get_pg(["dp"]).size()
Expand All @@ -64,7 +67,7 @@ def get_mimo_dp_info(
pp_rank = pp_group.rank()
pp_size = pp_group.size()

if my_module == "llm":
if my_module == MIMO_LANGUAGE_MODULE_KEY:
needs_data = (pp_rank == 0) or (pp_rank == pp_size - 1)
else:
needs_data = pp_rank == 0
Expand Down
12 changes: 7 additions & 5 deletions src/megatron/bridge/models/mimo/mimo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from dataclasses import dataclass, field
from typing import Optional

from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY


@dataclass
class ModuleParallelismConfig:
Expand Down Expand Up @@ -62,7 +64,7 @@ class MimoParallelismConfig:
Note: Phase 1 only supports heterogeneous deployment where each module
can have different parallelism configurations and rank offsets.

The LLM module must be named "llm" in module_parallelisms.
The language module must be named MIMO_LANGUAGE_MODULE_KEY ("language") in module_parallelisms.
"""

module_parallelisms: dict[str, ModuleParallelismConfig]
Expand Down Expand Up @@ -133,10 +135,10 @@ def is_power_of_two(n: int) -> bool:
# Validate encoder DP >= LLM DP for embedding alignment
# Encoder modules produce embeddings consumed by LLM. If encoder DP < LLM DP,
# the same encoder batch would need to align with different LLM batches, which fails.
llm_dp = self.module_parallelisms["llm"].data_parallel_size
llm_dp = self.module_parallelisms[MIMO_LANGUAGE_MODULE_KEY].data_parallel_size
if llm_dp is not None:
for name, p in self.module_parallelisms.items():
if name == "llm":
if name == MIMO_LANGUAGE_MODULE_KEY:
continue
encoder_dp = p.data_parallel_size
if encoder_dp is not None and encoder_dp < llm_dp:
Expand All @@ -152,9 +154,9 @@ def finalize(self, world_size: int) -> None:
world_size: Total number of ranks in the distributed world.
MIMO requires a distributed environment, so this must always be provided.
"""
if "llm" not in self.module_parallelisms:
if MIMO_LANGUAGE_MODULE_KEY not in self.module_parallelisms:
raise ValueError(
f"LLM module 'llm' must be in module_parallelisms. "
f"Language module '{MIMO_LANGUAGE_MODULE_KEY}' must be in module_parallelisms. "
f"Found modules: {list(self.module_parallelisms.keys())}"
)

Expand Down
17 changes: 11 additions & 6 deletions src/megatron/bridge/models/mimo/mimo_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,22 @@

Note: This module only supports DDP wrapping. FSDP is not yet implemented.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Dict, Optional

from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY

from megatron.bridge.models.mimo.mimo_builder import is_current_rank_in_grid


if TYPE_CHECKING:
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.hyper_comm_grid import HyperCommGrid
from megatron.core.models.mimo import MimoModel
from megatron.core.process_groups_config import ProcessGroupCollection

from megatron.bridge.models.mimo.mimo_config import MimoParallelismConfig


Expand Down Expand Up @@ -44,9 +49,9 @@ def wrap_mimo_model_distributed(

# Wrap language model if present and rank participates
if mimo_model.language_model is not None:
llm_grid = grids.get("llm")
llm_grid = grids.get(MIMO_LANGUAGE_MODULE_KEY)
if llm_grid is not None and is_current_rank_in_grid(llm_grid):
llm_pg = pg_collections.get("llm")
llm_pg = pg_collections.get(MIMO_LANGUAGE_MODULE_KEY)
if llm_pg is not None:
mimo_model.language_model = DistributedDataParallel(
config=mimo_model.language_model.config,
Expand All @@ -56,7 +61,7 @@ def wrap_mimo_model_distributed(
)

# Wrap modality submodules
if hasattr(mimo_model, 'modality_submodules'):
if hasattr(mimo_model, "modality_submodules"):
for module_name, submodule in mimo_model.modality_submodules.items():
if submodule is None:
continue
Expand All @@ -74,11 +79,11 @@ def wrap_mimo_model_distributed(
# Note: We use the first encoder's config for DDP bucket sizing.
# This assumes all encoders in a modality submodule share similar
# parallelism settings, which is typical for MIMO models.
if hasattr(submodule, 'encoders') and submodule.encoders:
if hasattr(submodule, "encoders") and submodule.encoders:
encoder_key = next(iter(submodule.encoders.keys()))
first_encoder = submodule.encoders[encoder_key]
if not hasattr(first_encoder, 'config'):

if not hasattr(first_encoder, "config"):
raise AttributeError(
f"Encoder '{encoder_key}' in modality '{module_name}' does not have "
f"a 'config' attribute. Encoders must be MegatronModule subclasses."
Expand Down
30 changes: 24 additions & 6 deletions src/megatron/bridge/models/mimo/mimo_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.models.mimo import MimoModel
from megatron.core.models.mimo.config.base_configs import MimoModelConfig
from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec
Expand Down Expand Up @@ -60,6 +61,7 @@ class MimoModelInfra:
topology: Dict[str, List[str]]
pg_collections: Dict[str, Optional[ProcessGroupCollection]]
participating_modules: List[str]
module_output_ndim: Dict[str, int] = field(default_factory=dict)


@dataclass
Expand All @@ -83,7 +85,7 @@ class MimoModelProvider(ModelProviderMixin[MimoModel]):
Example:
>>> mimo_parallelism_config = MimoParallelismConfig(
... module_parallelisms={
... "llm": ModuleParallelismConfig(tensor_model_parallel_size=8),
... "language": ModuleParallelismConfig(tensor_model_parallel_size=8),
... "clip_encoder": ModuleParallelismConfig(tensor_model_parallel_size=2),
... }
... )
Expand All @@ -108,10 +110,15 @@ class MimoModelProvider(ModelProviderMixin[MimoModel]):
mimo_parallelism_config: Optional[MimoParallelismConfig] = None

# Module data-flow DAG for MultiModulePipelineCommunicator.
# If None, auto-derived as: all modality_submodules → "llm" (terminal).
# Set explicitly for non-standard topologies (e.g., llm → generator).
# If None, auto-derived as: all modality_submodules → language module (terminal).
# Set explicitly for non-standard topologies (e.g., language → generator).
topology: Optional[Dict[str, List[str]]] = None

# Output tensor dimensionality per module for bridge communicator routing.
# Vision/audio encoders typically produce 2D [S, H]; language modules produce 3D [S, B, H].
# If None, auto-derived: language module → 3, all others → 2.
module_output_ndim: Optional[Dict[str, int]] = None

# Cached grids after build_model() - used by data loading
_grids: Optional[Dict[str, "HyperCommGrid"]] = field(default=None, repr=False)

Expand Down Expand Up @@ -150,18 +157,30 @@ def build_infra(self) -> MimoModelInfra:
if self.topology is not None:
topology = self.topology
else:
topology = {name: ["llm"] for name in self.modality_submodules_spec} | {"llm": []}
topology = {name: [MIMO_LANGUAGE_MODULE_KEY] for name in self.modality_submodules_spec} | {
MIMO_LANGUAGE_MODULE_KEY: []
}

# Cache grids for later use (e.g., data loading)
object.__setattr__(self, "_grids", grids)

participating_modules = [name for name, pg in pg_collections.items() if pg is not None]

# Derive module output tensor dimensionality if not explicitly configured.
if self.module_output_ndim is not None:
output_ndim = self.module_output_ndim
else:
output_ndim = {
name: 3 if name == MIMO_LANGUAGE_MODULE_KEY else 2
for name in grids
}

return MimoModelInfra(
module_to_grid_map=grids,
topology=topology,
pg_collections=pg_collections,
participating_modules=participating_modules,
module_output_ndim=output_ndim,
)

def _get_pg_collections_from_grids(
Expand Down Expand Up @@ -289,7 +308,7 @@ def provide(
# Inject pg_collection into language model spec
language_spec = self.language_model_spec
if self.mimo_parallelism_config:
llm_pg = infra.pg_collections.get("llm")
llm_pg = infra.pg_collections.get(MIMO_LANGUAGE_MODULE_KEY)
if llm_pg is not None:
language_spec = self._inject_pg_collection_into_language_spec(
language_spec,
Expand All @@ -312,7 +331,6 @@ def provide(
modality_submodules_spec=modality_specs,
special_token_ids=self.special_token_ids,
module_to_grid_map=(infra.module_to_grid_map if self.mimo_parallelism_config is not None else None),
language_module_key="llm" if self.mimo_parallelism_config is not None else None,
)

mimo_model = MimoModel(mimo_model_config)
Expand Down
5 changes: 3 additions & 2 deletions src/megatron/bridge/training/mimo_parallel_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import torch.distributed as dist
from megatron.core.distributed.finalize_model_grads import finalize_model_grads as _finalize_model_grads
from megatron.core.models.mimo import MimoModel
from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY

from megatron.bridge.models.mimo.mimo_provider import MimoModelInfra

Expand Down Expand Up @@ -94,7 +95,7 @@ def get_module_to_grid_tuple(
continue

# Get the actual module from the unwrapped model
if module_name == "llm":
if module_name == MIMO_LANGUAGE_MODULE_KEY:
module = unwrapped_model.language_model
elif hasattr(unwrapped_model, "modality_submodules") and module_name in unwrapped_model.modality_submodules:
module = unwrapped_model.modality_submodules[module_name]
Expand Down Expand Up @@ -128,7 +129,7 @@ def build_pg_collection_for_schedule(infra: MimoModelInfra):
module_pgs = {k: v for k, v in infra.pg_collections.items() if v is not None}
if not module_pgs:
raise ValueError("module_pgs dict cannot be empty")
language_model_module_name = "llm" if "llm" in module_pgs else None
language_model_module_name = MIMO_LANGUAGE_MODULE_KEY if MIMO_LANGUAGE_MODULE_KEY in module_pgs else None
return MultiModuleProcessGroupCollection(
module_pgs=module_pgs,
language_model_module_name=language_model_module_name,
Expand Down
5 changes: 3 additions & 2 deletions src/megatron/bridge/training/mimo_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import torch
from megatron.core.models.mimo import MimoModel
from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY

from megatron.bridge.training.mimo_parallel_utils import unwrap_mimo_model
from megatron.bridge.training.state import GlobalState
Expand Down Expand Up @@ -135,7 +136,7 @@ def forward_step(
needs_data = True
if mimo_model.role is not None:
if mimo_model.role.has_language_module:
module_name = mimo_model.role.language_module_name
module_name = MIMO_LANGUAGE_MODULE_KEY
is_first_stage = mimo_model.role.is_first_stage(module_name)
is_last_stage = mimo_model.role.is_last_stage(module_name)
needs_data = is_first_stage or is_last_stage
Expand Down Expand Up @@ -182,7 +183,7 @@ def forward_step(
if mimo_model.role is None:
is_last_stage = True
elif mimo_model.role.has_language_module:
is_last_stage = mimo_model.role.is_last_stage(mimo_model.role.language_module_name)
is_last_stage = mimo_model.role.is_last_stage(MIMO_LANGUAGE_MODULE_KEY)
else:
is_last_stage = False

Expand Down
6 changes: 1 addition & 5 deletions src/megatron/bridge/training/pretrain_mimo.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def setup_mimo(
mimo_infra.topology,
model_config,
dim_mapping={"s": 0, "b": 1, "h": 2}, # SBH mapping - matches MimoModel output
module_output_ndim=mimo_infra.module_output_ndim,
)

# Build pg_collection for schedule
Expand Down Expand Up @@ -261,11 +262,6 @@ def pretrain_mimo(
"MimoModelConfig.module_to_grid_map must be set at model construction time. "
"Ensure MimoModelProvider.provide() passes module_to_grid_map for MIMO parallelism."
)
assert unwrapped_model.mimo_config.language_module_key is not None, (
"MimoModelConfig.language_module_key must be set at model construction time. "
"Ensure MimoModelProvider.provide() sets language_module_key for MIMO parallelism."
)

logger.info(f"Rank {dist.get_rank()}: Creating MimoOptimizer")

# Create MimoOptimizer using the factory function
Expand Down
5 changes: 3 additions & 2 deletions src/megatron/bridge/training/train_mimo.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import torch
import torch.distributed as dist
from megatron.core.models.mimo.config.role import MIMO_LANGUAGE_MODULE_KEY
from megatron.core.num_microbatches_calculator import get_num_microbatches
from megatron.core.pipeline_parallel.schedules import forward_backward_pipelining_without_interleaving
from megatron.core.utils import get_model_config
Expand Down Expand Up @@ -144,10 +145,10 @@ def train_step_mimo(
if mimo_model.role is None:
is_last_stage = True
elif mimo_model.role.has_language_module:
is_last_stage = mimo_model.role.is_last_stage(mimo_model.role.language_module_name)
is_last_stage = mimo_model.role.is_last_stage(MIMO_LANGUAGE_MODULE_KEY)

if is_last_stage:
llm_pg = infra.pg_collections.get("llm") if infra.pg_collections else None
llm_pg = infra.pg_collections.get(MIMO_LANGUAGE_MODULE_KEY) if infra.pg_collections else None
for key in losses_reduced[0].keys():
val = [x[key].view(-1) for x in losses_reduced]
if val[0].numel() == 2:
Expand Down
Loading
Loading