Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/megatron/bridge/models/mimo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.

from megatron.bridge.models.mimo.llava_provider import LlavaMimoProvider
from megatron.bridge.models.mimo.mimo_config import (
Expand All @@ -12,9 +12,9 @@


__all__ = [
"LlavaMimoProvider",
"MimoModelInfra",
"MimoModelProvider",
"MimoParallelismConfig",
"ModuleParallelismConfig",
"MimoModelProvider",
"MimoModelInfra",
"LlavaMimoProvider",
]
60 changes: 57 additions & 3 deletions src/megatron/bridge/models/mimo/mimo_builder.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.

from __future__ import annotations

from typing import TYPE_CHECKING, Dict, List
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

import torch.distributed as dist

from megatron.bridge.models.mimo.mimo_config import MimoParallelismConfig


if TYPE_CHECKING:
from megatron.core.process_groups_config import HyperCommGrid
from megatron.core.hyper_comm_grid import HyperCommGrid


def build_hypercomm_grids(
Expand Down Expand Up @@ -56,3 +58,55 @@ def build_hypercomm_grids(
def _default_topology(mimo_parallelism_config: MimoParallelismConfig) -> Dict[str, List[str]]:
"""Infer a default multi-encoder -> LLM topology."""
return {name: ["llm"] for name in mimo_parallelism_config.module_names if name != "llm"} | {"llm": []}


def create_embedding_and_position_groups(
    pp_group: dist.ProcessGroup,
) -> Tuple[Optional[dist.ProcessGroup], Optional[dist.ProcessGroup]]:
    """Create embedding-related process groups from PP group ranks.

    Following MCore semantics:
    - pos_embd_pg: only the first PP stage - for position embeddings
    - embd_pg: the first and last PP stages - for tied word embeddings

    IMPORTANT: ``dist.new_group`` is a collective operation, so this function
    must be called on all ranks that could participate.

    Note: VPP (virtual_pipeline_model_parallel_size > 1) is not supported.
    With VPP the first/last entries of the sorted PP ranks do not reliably
    identify the stages that own embeddings. The caller is responsible for
    asserting VPP is disabled.

    Args:
        pp_group: The pipeline parallel process group.

    Returns:
        Tuple of (pos_embd_pg, embd_pg). Returns (None, None) if pp_group is None.
    """
    if pp_group is None:
        return None, None

    sorted_ranks = sorted(dist.get_process_group_ranks(pp_group))
    first_rank = sorted_ranks[0]
    last_rank = sorted_ranks[-1]

    # Position embeddings live only on the first PP stage.
    pos_embd_pg = dist.new_group(ranks=[first_rank])

    # Tied word embeddings live on the first and last PP stages.
    word_embd_ranks = [first_rank]
    if len(sorted_ranks) > 1 and last_rank != first_rank:
        word_embd_ranks.append(last_rank)
    embd_pg = dist.new_group(ranks=word_embd_ranks)

    return pos_embd_pg, embd_pg


def is_current_rank_in_grid(grid: "HyperCommGrid") -> bool:
    """Check if the current rank participates in this grid.

    Args:
        grid: A HyperCommGrid instance.

    Returns:
        True if dist.get_rank() is within the grid's rank range.
    """
    current_rank = dist.get_rank()
    lower = grid.rank_offset
    upper = lower + grid.size
    return lower <= current_rank < upper
38 changes: 34 additions & 4 deletions src/megatron/bridge/models/mimo/mimo_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.

from __future__ import annotations

Expand Down Expand Up @@ -86,18 +86,48 @@ def total_world_size(self) -> int:
def _validate_heterogeneous(self) -> None:
"""Validate heterogeneous deployment: no overlapping rank ranges."""
ranges = []
for parallelism in self.module_parallelisms.values():
for name, parallelism in self.module_parallelisms.items():
if parallelism.data_parallel_size is None:
raise ValueError("data_parallel_size must be set for heterogeneous deployment.")
ranges.append((parallelism.rank_offset, parallelism.rank_offset + parallelism.total_ranks))
ranges.append((parallelism.rank_offset, parallelism.rank_offset + parallelism.total_ranks, name))

ranges.sort()
ranges.sort(key=lambda x: x[0])
for idx in range(1, len(ranges)):
prev_end = ranges[idx - 1][1]
cur_start = ranges[idx][0]
if cur_start < prev_end:
raise ValueError("rank_offset ranges overlap in heterogeneous deployment.")

# Check for gaps between modules (likely misconfiguration)
# Gaps in the middle are errors; leading gaps (rank_offset > 0) are warnings
if ranges:
min_rank = ranges[0][0] # Already sorted by rank_offset
max_rank = ranges[-1][1]

# Collect all covered ranks
covered_ranks = set()
for parallelism in self.module_parallelisms.values():
start = parallelism.rank_offset
end = start + parallelism.total_ranks
covered_ranks.update(range(start, end))

# Check for gaps between min and max (error - likely misconfiguration)
expected_middle = set(range(min_rank, max_rank))
gaps_in_middle = expected_middle - covered_ranks
if gaps_in_middle:
raise ValueError(
f"Ranks {sorted(gaps_in_middle)} are not assigned to any module in heterogeneous "
f"deployment. This creates a gap between modules which is not allowed."
)

# Check for leading gap (ranks 0 to min_rank-1 unused) - warning only
if min_rank > 0:
warnings.warn(
f"Ranks {list(range(min_rank))} (before first module) are not assigned to any "
f"module in heterogeneous deployment. These ranks will be idle during training.",
stacklevel=3,
)

def finalize(self, world_size: Optional[int]) -> None:
"""Finalize parallelism config: compute data_parallel_size and validate."""
if "llm" not in self.module_parallelisms:
Expand Down
96 changes: 96 additions & 0 deletions src/megatron/bridge/models/mimo/mimo_ddp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
"""DDP wrapping utilities for MIMO models.

Called from the training layer after MimoModelProvider.provide().

Note: This module only supports DDP wrapping. FSDP is not yet implemented.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Dict, Optional

from megatron.bridge.models.mimo.mimo_builder import is_current_rank_in_grid


if TYPE_CHECKING:
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.hyper_comm_grid import HyperCommGrid
from megatron.core.models.mimo import MimoModel
from megatron.core.process_groups_config import ProcessGroupCollection

from megatron.bridge.models.mimo.mimo_config import MimoParallelismConfig


def wrap_mimo_model_distributed(
    mimo_model: "MimoModel",
    ddp_config: "DistributedDataParallelConfig",
    mimo_parallelism_config: "MimoParallelismConfig",
    grids: Dict[str, "HyperCommGrid"],
    pg_collections: Dict[str, Optional["ProcessGroupCollection"]],
) -> "MimoModel":
    """Wrap the MIMO model's submodules with DDP.

    The model is modified in place and returned for convenience.

    Args:
        mimo_model: The MimoModel to wrap.
        ddp_config: DDP configuration from Bridge.
        mimo_parallelism_config: MIMO parallelism configuration (currently
            unused by the wrapping logic itself; kept for API stability).
        grids: Module name to HyperCommGrid mapping.
        pg_collections: Module name to ProcessGroupCollection mapping.

    Returns:
        The same mimo_model with wrapped submodules.
    """
    from megatron.core.distributed import DistributedDataParallel

    # Language model: wrap only when present, this rank is in the LLM grid,
    # and a process-group collection was provided.
    language_model = mimo_model.language_model
    if language_model is not None and is_current_rank_in_grid(grids["llm"]):
        llm_pg_collection = pg_collections.get("llm")
        if llm_pg_collection is not None:
            mimo_model.language_model = DistributedDataParallel(
                config=language_model.config,
                ddp_config=ddp_config,
                module=language_model,
                pg_collection=llm_pg_collection,
            )

    # Modality submodules: same participation checks per module.
    if hasattr(mimo_model, "modality_submodules"):
        for name, submodule in mimo_model.modality_submodules.items():
            if submodule is None or not is_current_rank_in_grid(grids[name]):
                continue

            pg_collection = pg_collections.get(name)
            if pg_collection is None:
                continue

            # DDP bucket sizing uses the first encoder's config.
            # This assumes all encoders in a modality submodule share similar
            # parallelism settings, which is typical for MIMO models.
            if not (hasattr(submodule, "encoders") and submodule.encoders):
                continue

            first_key = next(iter(submodule.encoders))
            first_encoder = submodule.encoders[first_key]

            if not hasattr(first_encoder, "config"):
                raise AttributeError(
                    f"Encoder '{first_key}' in modality '{name}' does not have "
                    f"a 'config' attribute. Encoders must be MegatronModule subclasses."
                )

            mimo_model.modality_submodules[name] = DistributedDataParallel(
                config=first_encoder.config,
                ddp_config=ddp_config,
                module=submodule,
                pg_collection=pg_collection,
            )

    return mimo_model
Loading
Loading