Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
c601de4
add pp stage checkers to p2p communicator
yashaswikarnati Jan 27, 2026
84ae4f0
add process group collection wrapper
yashaswikarnati Jan 27, 2026
0fa3dd8
support multimodule pipelining in 1f1b schedule
yashaswikarnati Jan 27, 2026
b22f638
fix dim mapping in torch cat bridge comm
yashaswikarnati Jan 28, 2026
3badf57
handle 3d 2d tensor conversion in multimodule comm
yashaswikarnati Jan 28, 2026
20d03f5
add unit tests for multimodule pipeline schedules
yashaswikarnati Jan 28, 2026
a6606d8
refactor multimodule pg collection and backward step
yashaswikarnati Jan 28, 2026
b102eb7
rename module_collections to module_pgs for clarity
yashaswikarnati Jan 28, 2026
ebbb509
rename tensor conversion functions for clarity
yashaswikarnati Jan 28, 2026
2d7c176
Merge branch 'main' into yash/1f1b_changes
dimapihtar Jan 29, 2026
0b6cefd
Fix linting issues: format code and remove unused imports
yashaswikarnati Feb 3, 2026
7d566d9
Add RankRole and ModuleStageInfo for multi-module pipeline parallelism
yashaswikarnati Feb 2, 2026
997dfa5
Add stage-aware forward pass to modality submodules
yashaswikarnati Feb 2, 2026
a1a8fdc
Update MimoModel for multi-module pipeline parallelism
yashaswikarnati Feb 2, 2026
b46a157
Add unit tests for multi-module pipeline parallelism
yashaswikarnati Feb 2, 2026
7da19e1
Add .worktrees/ to gitignore
yashaswikarnati Feb 2, 2026
5b94c0f
Merge branch 'main' into yash/mimo_non_colocated
dimapihtar Feb 4, 2026
84a1ebf
Merge upstream main into yash/mimo_non_colocated
yashaswikarnati Mar 19, 2026
62dfb89
Simplify MIMO model and consolidate submodule logic
yashaswikarnati Mar 19, 2026
99800d2
Simplify and deduplicate MIMO model tests
yashaswikarnati Mar 19, 2026
bf75c2d
Make RankRole always-present and simplify encoder dispatch
yashaswikarnati Mar 19, 2026
4bb9711
Replace configurable language_module_key with fixed LANGUAGE_MODULE_KEY
yashaswikarnati Mar 19, 2026
3e97b32
Improve MIMO pipeline abstraction and type safety
yashaswikarnati Mar 19, 2026
03207bc
Rename PipelineMode to ModuleLayout
yashaswikarnati Mar 19, 2026
22fbeba
Fix non-colocated forward pass and rewrite MIMO 1F1B schedule tests
yashaswikarnati Mar 20, 2026
bd1b01b
Merge remote-tracking branch 'origin/main' into yash/mimo_non_colocated
yashaswikarnati Mar 21, 2026
1335de6
Fix stale MiMo tests for multi-rank execution
yashaswikarnati Mar 22, 2026
a94098d
Fix copyright year to 2026 in new files
yashaswikarnati Mar 23, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,7 @@ runs/

# Sphinx documentation
docs/_build
docs/apidocs
docs/apidocs

# Git worktrees
.worktrees/
3 changes: 2 additions & 1 deletion megatron/core/models/mimo/config/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

from megatron.core.models.mimo.config.base_configs import MimoModelConfig
from megatron.core.models.mimo.config.role import ModuleStageInfo, RankRole

__all__ = ['MimoModelConfig']
__all__ = ['MimoModelConfig', 'ModuleStageInfo', 'RankRole']
9 changes: 8 additions & 1 deletion megatron/core/models/mimo/config/base_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

import warnings
from dataclasses import dataclass, field
from typing import Dict
from typing import Dict, Optional

from megatron.core.hyper_comm_grid import HyperCommGrid
from megatron.core.transformer.spec_utils import ModuleSpec


Expand All @@ -20,6 +21,11 @@ class MimoModelConfig:
Dictionary mapping modality names to their special token IDs.
For example, {"vision": -200, "audio":32000}, these represent placeholders
in the input_ids to insert the modality embeddings at the correct positions.
module_to_grid_map (Optional[Dict[str, HyperCommGrid]]):
Dictionary mapping module keys (e.g., "vision", "language") to their
corresponding HyperCommGrid configurations for non-colocated pipeline
parallelism. The language model must use the key MIMO_LANGUAGE_MODULE_KEY.
When None, all modules are assumed to be colocated on the same ranks.
kv_format (str):
Key-value format for attention: "sbhd" (seq-batch-head-dim) or "thd" (total-head-dim).
Default is "sbhd".
Expand All @@ -35,4 +41,5 @@ class MimoModelConfig:
language_model_spec: ModuleSpec = field(default_factory=ModuleSpec)
modality_submodules_spec: Dict[str, ModuleSpec] = field(default_factory=dict)
special_token_ids: Dict[str, int] = field(default_factory=dict)
module_to_grid_map: Optional[Dict[str, HyperCommGrid]] = None
kv_format: str = "sbhd"
172 changes: 172 additions & 0 deletions megatron/core/models/mimo/config/role.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.

"""Data classes for MIMO rank role management in multi-module pipeline parallelism."""

import logging
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List

import torch.distributed as dist

from megatron.core.hyper_comm_grid import HyperCommGrid

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Fixed key for the language module in module_to_grid_map and RankRole.
# MIMO always has exactly one language model, so this is not configurable.
# RankRole compares module names against this key to tell the language
# module apart from the modality (non-language) modules.
MIMO_LANGUAGE_MODULE_KEY = "language"


class ModuleLayout(Enum):
    """Layout of MIMO modules across ranks for multi-module parallelism.

    The layout controls how modules are spread over ranks and, as a
    consequence, which forward path is taken.

    Members:
        UNIFIED: There is no module_to_grid_map. Every module lives on the
            same ranks with the same parallelism, and the unified forward
            path (_forward_all_modules) is used.
        NON_COLOCATED: module_to_grid_map is provided and its rank ranges do
            not overlap. A rank runs EITHER encoder(s) OR the language model,
            and forward is dispatched by role through separate paths.
        COLOCATED: (future) module_to_grid_map is provided and its rank
            ranges overlap. Encoder(s) and the language model share ranks
            while using different parallelism configs. Dispatch is still
            role-based, but both module types may sit on one rank.
    """

    # All modules on the same ranks; no grid map.
    UNIFIED = "unified"
    # Disjoint rank ranges per module.
    NON_COLOCATED = "non_colocated"
    # Overlapping rank ranges (future).
    COLOCATED = "colocated"


@dataclass
class ModuleStageInfo:
    """Stage position of a rank inside one module's pipeline.

    Attributes:
        is_first_stage: Whether the rank is the first PP stage of the module.
        is_last_stage: Whether the rank is the last PP stage of the module.
    """

    # Both flags are required; neither has a default.
    is_first_stage: bool
    is_last_stage: bool


@dataclass
class RankRole:
    """Record of the modules the current rank takes part in under multi-module PP.

    For each module the rank participates in, the role stores the rank's stage
    position (first / last) within that module's pipeline. The language module
    is always the entry keyed by MIMO_LANGUAGE_MODULE_KEY; every other entry is
    treated as a modality module.

    Attributes:
        modules: Maps the name of each module this rank participates in to its
            ModuleStageInfo.
        mode: ModuleLayout that decides how the forward path is dispatched.
    """

    modules: Dict[str, ModuleStageInfo] = field(default_factory=dict)
    mode: ModuleLayout = ModuleLayout.UNIFIED

    @classmethod
    def unified(cls, module_names: List[str]) -> 'RankRole':
        """Build the unified role: each named module is both first and last stage."""
        stage_by_module = {}
        for module_name in module_names:
            # A fresh ModuleStageInfo per module, so entries never share state.
            stage_by_module[module_name] = ModuleStageInfo(is_first_stage=True, is_last_stage=True)
        return cls(modules=stage_by_module, mode=ModuleLayout.UNIFIED)

    @classmethod
    def from_grid_map(
        cls, module_to_grid_map: Dict[str, HyperCommGrid], modality_module_names: List[str]
    ) -> 'RankRole':
        """Build the role of the current rank from a module-to-grid mapping.

        Works out which module grids contain the current rank (as reported by
        dist.get_rank()) and, for each of them, whether the rank is the first
        and/or last pipeline stage.

        Args:
            module_to_grid_map: Maps module names to HyperCommGrid objects. Its
                keys must be exactly modality_module_names plus
                MIMO_LANGUAGE_MODULE_KEY.
            modality_module_names: Names of the modality modules
                (e.g., ["images", "audio"]).

        Returns:
            A RankRole for the current rank with mode ModuleLayout.NON_COLOCATED.

        Raises:
            ValueError: If the grid map keys differ from the expected module names.
            RuntimeError: If the current rank falls inside none of the grids.
        """
        # The key set has to match exactly: no missing and no extra modules.
        expected_keys = set(modality_module_names) | {MIMO_LANGUAGE_MODULE_KEY}
        grid_keys = set(module_to_grid_map)
        if grid_keys != expected_keys:
            missing = expected_keys - grid_keys
            extra = grid_keys - expected_keys
            raise ValueError(
                f"module_to_grid_map keys must match modality module names + "
                f"'{MIMO_LANGUAGE_MODULE_KEY}'. Missing: {missing}, "
                f"Extra: {extra}"
            )

        current_rank = dist.get_rank()
        stage_by_module = {}

        for name, grid in module_to_grid_map.items():
            # Skip grids whose [rank_offset, rank_offset + size) range excludes this rank.
            if current_rank < grid.rank_offset or current_rank >= grid.rank_offset + grid.size:
                continue

            if "pp" in grid.dim_names:
                pp_group = grid.get_pg("pp")
                pp_rank = pp_group.rank()
                pp_size = pp_group.size()
                first = pp_rank == 0
                last = pp_rank == pp_size - 1
                logger.info(
                    f"[RankRole.from_grid_map] Rank {current_rank}: module={name}, "
                    f"pp_rank={pp_rank}/{pp_size}, is_first_stage={first}, is_last_stage={last}"
                )
            else:
                # Without a "pp" dimension the rank is both first and last stage.
                first = True
                last = True

            stage_by_module[name] = ModuleStageInfo(is_first_stage=first, is_last_stage=last)

        if len(stage_by_module) == 0:
            raise RuntimeError(
                f"Rank {current_rank} is not in any module grid. "
                f"Check module_to_grid_map configuration."
            )

        return cls(modules=stage_by_module, mode=ModuleLayout.NON_COLOCATED)

    @property
    def has_modality_modules(self) -> bool:
        """Whether this rank participates in at least one non-language module."""
        for name in self.modules:
            if name != MIMO_LANGUAGE_MODULE_KEY:
                return True
        return False

    @property
    def has_language_module(self) -> bool:
        """Whether this rank participates in the language module."""
        return MIMO_LANGUAGE_MODULE_KEY in self.modules

    @property
    def modality_module_names(self) -> List[str]:
        """Names of the non-language modules this rank participates in."""
        names = []
        for name in self.modules:
            if name != MIMO_LANGUAGE_MODULE_KEY:
                names.append(name)
        return names

    def is_first_stage(self, module_name: str) -> bool:
        """Whether this rank is the first stage of module_name (False if not a participant)."""
        return module_name in self.modules and self.modules[module_name].is_first_stage

    def is_last_stage(self, module_name: str) -> bool:
        """Whether this rank is the last stage of module_name (False if not a participant)."""
        return module_name in self.modules and self.modules[module_name].is_last_stage
Loading
Loading