Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/megatron/bridge/models/mimo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.

from megatron.bridge.models.mimo.llava_provider import LlavaMimoProvider
from megatron.bridge.models.mimo.mimo_config import (
Expand All @@ -12,9 +12,9 @@


__all__ = [
"LlavaMimoProvider",
"MimoModelInfra",
"MimoModelProvider",
"MimoParallelismConfig",
"ModuleParallelismConfig",
"MimoModelProvider",
"MimoModelInfra",
"LlavaMimoProvider",
]
Comment thread
coderabbitai[bot] marked this conversation as resolved.
74 changes: 72 additions & 2 deletions src/megatron/bridge/models/mimo/mimo_builder.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.

from __future__ import annotations

from typing import TYPE_CHECKING, Dict, List
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

import torch.distributed as dist

from megatron.bridge.models.mimo.mimo_config import MimoParallelismConfig


if TYPE_CHECKING:
from megatron.core.hyper_comm_grid import HyperCommGrid


if TYPE_CHECKING:
from megatron.core.process_groups_config import HyperCommGrid

Expand Down Expand Up @@ -56,3 +62,67 @@ def build_hypercomm_grids(
def _default_topology(mimo_parallelism_config: MimoParallelismConfig) -> Dict[str, List[str]]:
"""Infer a default multi-encoder -> LLM topology."""
return {name: ["llm"] for name in mimo_parallelism_config.module_names if name != "llm"} | {"llm": []}


def populate_embedding_and_position_groups(
Comment thread
aroshanghias-nvd marked this conversation as resolved.
Outdated
pp_group: dist.ProcessGroup,
Comment thread
aroshanghias-nvd marked this conversation as resolved.
) -> Tuple[Optional[dist.ProcessGroup], Optional[dist.ProcessGroup]]:
"""Create embedding-related process groups from PP group ranks.

Following MCore semantics:
- pos_embd_pg: Only rank 0 of PP (first stage) - for position embeddings
- embd_pg: Ranks 0 and -1 of PP (first and last stages) - for tied word embeddings

IMPORTANT: This calls dist.new_group which is a collective operation.
Must be called on all ranks that could participate.

Args:
pp_group: The pipeline parallel process group.

Returns:
Tuple of (pos_embd_pg, embd_pg). Returns (None, None) if pp_group is None.
"""
if pp_group is None:
return None, None

pp_ranks = sorted(dist.get_process_group_ranks(pp_group))

# Position embeddings only on first PP stage
pos_embd_ranks = [pp_ranks[0]]
pos_embd_pg = dist.new_group(ranks=pos_embd_ranks)

# Word embeddings on first and last PP stages (for tied embeddings)
embd_ranks = [pp_ranks[0]]
if len(pp_ranks) > 1 and pp_ranks[-1] != pp_ranks[0]:
embd_ranks.append(pp_ranks[-1])
embd_pg = dist.new_group(ranks=embd_ranks)

return pos_embd_pg, embd_pg


def is_pp_first_stage(pp_group: Optional[dist.ProcessGroup]) -> bool:
Comment thread
aroshanghias-nvd marked this conversation as resolved.
Outdated
"""Check if current rank is first stage in pipeline."""
if pp_group is None:
return True
pp_ranks = sorted(dist.get_process_group_ranks(pp_group))
return dist.get_rank() == pp_ranks[0]


def is_pp_last_stage(pp_group: Optional[dist.ProcessGroup]) -> bool:
"""Check if current rank is last stage in pipeline."""
if pp_group is None:
return True
pp_ranks = sorted(dist.get_process_group_ranks(pp_group))
return dist.get_rank() == pp_ranks[-1]


def is_current_rank_in_grid(grid: "HyperCommGrid") -> bool:
"""Check if the current rank participates in this grid.

Args:
grid: A HyperCommGrid instance.

Returns:
True if dist.get_rank() is within the grid's rank range.
"""
return grid.rank_offset <= dist.get_rank() < (grid.rank_offset + grid.size)
38 changes: 34 additions & 4 deletions src/megatron/bridge/models/mimo/mimo_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.

from __future__ import annotations

Expand Down Expand Up @@ -86,18 +86,48 @@ def total_world_size(self) -> int:
def _validate_heterogeneous(self) -> None:
"""Validate heterogeneous deployment: no overlapping rank ranges."""
ranges = []
for parallelism in self.module_parallelisms.values():
for name, parallelism in self.module_parallelisms.items():
if parallelism.data_parallel_size is None:
raise ValueError("data_parallel_size must be set for heterogeneous deployment.")
ranges.append((parallelism.rank_offset, parallelism.rank_offset + parallelism.total_ranks))
ranges.append((parallelism.rank_offset, parallelism.rank_offset + parallelism.total_ranks, name))

ranges.sort()
ranges.sort(key=lambda x: x[0])
for idx in range(1, len(ranges)):
prev_end = ranges[idx - 1][1]
cur_start = ranges[idx][0]
if cur_start < prev_end:
raise ValueError("rank_offset ranges overlap in heterogeneous deployment.")

# Check for gaps between modules (likely misconfiguration)
# Gaps in the middle are errors; leading gaps (rank_offset > 0) are warnings
if ranges:
min_rank = ranges[0][0] # Already sorted by rank_offset
max_rank = ranges[-1][1]

# Collect all covered ranks
covered_ranks = set()
for parallelism in self.module_parallelisms.values():
start = parallelism.rank_offset
end = start + parallelism.total_ranks
covered_ranks.update(range(start, end))

# Check for gaps between min and max (error - likely misconfiguration)
expected_middle = set(range(min_rank, max_rank))
gaps_in_middle = expected_middle - covered_ranks
if gaps_in_middle:
raise ValueError(
f"Ranks {sorted(gaps_in_middle)} are not assigned to any module in heterogeneous "
f"deployment. This creates a gap between modules which is not allowed."
)

# Check for leading gap (ranks 0 to min_rank-1 unused) - warning only
if min_rank > 0:
warnings.warn(
f"Ranks {list(range(min_rank))} (before first module) are not assigned to any "
f"module in heterogeneous deployment. These ranks will be idle during training.",
stacklevel=3,
)

def finalize(self, world_size: Optional[int]) -> None:
"""Finalize parallelism config: compute data_parallel_size and validate."""
if "llm" not in self.module_parallelisms:
Expand Down
98 changes: 98 additions & 0 deletions src/megatron/bridge/models/mimo/mimo_ddp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
"""DDP wrapping utilities for MIMO models.

Called from the training layer after MimoModelProvider.provide().

Note: This module only supports DDP wrapping. FSDP is not yet implemented.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Dict, Optional

from megatron.bridge.models.mimo.mimo_builder import is_current_rank_in_grid


if TYPE_CHECKING:
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.hyper_comm_grid import HyperCommGrid
from megatron.core.models.mimo import MimoModel
from megatron.core.process_groups_config import ProcessGroupCollection

from megatron.bridge.models.mimo.mimo_config import MimoParallelismConfig


def wrap_mimo_model_distributed(
Comment thread
aroshanghias-nvd marked this conversation as resolved.
mimo_model: "MimoModel",
ddp_config: "DistributedDataParallelConfig",
mimo_parallelism_config: "MimoParallelismConfig",
grids: Dict[str, "HyperCommGrid"],
pg_collections: Dict[str, Optional["ProcessGroupCollection"]],
) -> "MimoModel":
"""Wrap MIMO model's submodules with DDP.

Modifies mimo_model in-place and returns it.

Args:
mimo_model: The MimoModel to wrap.
ddp_config: DDP configuration from Bridge.
mimo_parallelism_config: MIMO parallelism configuration.
grids: Module name to HyperCommGrid mapping.
pg_collections: Module name to ProcessGroupCollection mapping.

Returns:
The same mimo_model with wrapped submodules.
"""
from megatron.core.distributed import DistributedDataParallel

# Wrap language model if present and rank participates
if mimo_model.language_model is not None:
llm_grid = grids.get("llm")
Comment thread
aroshanghias-nvd marked this conversation as resolved.
Outdated
if llm_grid is not None and is_current_rank_in_grid(llm_grid):
llm_pg = pg_collections.get("llm")
if llm_pg is not None:
mimo_model.language_model = DistributedDataParallel(
config=mimo_model.language_model.config,
ddp_config=ddp_config,
module=mimo_model.language_model,
pg_collection=llm_pg,
)

# Wrap modality submodules
if hasattr(mimo_model, "modality_submodules"):
for module_name, submodule in mimo_model.modality_submodules.items():
if submodule is None:
continue
module_grid = grids.get(module_name)
if module_grid is None:
Comment thread
aroshanghias-nvd marked this conversation as resolved.
Outdated
continue
if not is_current_rank_in_grid(module_grid):
continue

module_pg = pg_collections.get(module_name)
if module_pg is None:
continue

# Get config from first encoder in the submodule.
# Note: We use the first encoder's config for DDP bucket sizing.
# This assumes all encoders in a modality submodule share similar
# parallelism settings, which is typical for MIMO models.
if hasattr(submodule, "encoders") and submodule.encoders:
encoder_key = next(iter(submodule.encoders.keys()))
first_encoder = submodule.encoders[encoder_key]

if not hasattr(first_encoder, "config"):
raise AttributeError(
f"Encoder '{encoder_key}' in modality '{module_name}' does not have "
f"a 'config' attribute. Encoders must be MegatronModule subclasses."
)

wrapped = DistributedDataParallel(
config=first_encoder.config,
ddp_config=ddp_config,
module=submodule,
pg_collection=module_pg,
)
mimo_model.modality_submodules[module_name] = wrapped

return mimo_model
Loading
Loading