Merged
1 change: 1 addition & 0 deletions src/megatron/bridge/models/mimo/mimo_builder.py
@@ -50,6 +50,7 @@ def build_hypercomm_grids(
_ = grid.create_pg(["tp", "pp"])
_ = grid.create_pg(["tp", "ep", "pp"])
_ = grid.create_pg(["dp", "ep"])
_ = grid.create_pg(["tp", "cp", "ep", "pp", "dp"])

grids[module_name] = grid

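The new line above adds a group spanning every parallelism axis to each module's HyperComm grid. As a rough reference (the helper below is illustrative; only grid.create_pg comes from the code above), the axis combinations visible in this hunk are:

REQUIRED_PG_AXES = [
    ["tp", "pp"],
    ["tp", "ep", "pp"],
    ["dp", "ep"],
    ["tp", "cp", "ep", "pp", "dp"],  # new: one group spanning every parallelism axis
]

def create_required_pgs(grid):
    # Illustrative sketch: create each axis-combination group on a per-module grid.
    return [grid.create_pg(axes) for axes in REQUIRED_PG_AXES]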
3 changes: 2 additions & 1 deletion src/megatron/bridge/models/mimo/mimo_provider.py
@@ -110,7 +110,7 @@ class MimoModelProvider(ModelProviderMixin[MimoModel]):
mimo_parallelism_config: Optional[MimoParallelismConfig] = None

# Module data-flow DAG for MultiModulePipelineCommunicator.
# If None, auto-derived as: all modality_submodules → language module (terminal).
# If None, auto-derived as: all modality_submodules → MIMO_LANGUAGE_MODULE_KEY (terminal).
# Set explicitly for non-standard topologies (e.g., language → generator).
topology: Optional[Dict[str, List[str]]] = None

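A minimal sketch of the auto-derivation described in the comment above; the literal value of MIMO_LANGUAGE_MODULE_KEY is assumed to be "language", matching the topologies used in the updated tests below:

from typing import Dict, List

MIMO_LANGUAGE_MODULE_KEY = "language"  # assumed value, consistent with the test topologies

def default_topology(modality_submodules: List[str]) -> Dict[str, List[str]]:
    # Every modality submodule feeds the language module; the language module is terminal.
    topology = {name: [MIMO_LANGUAGE_MODULE_KEY] for name in modality_submodules}
    topology[MIMO_LANGUAGE_MODULE_KEY] = []
    return topology

# default_topology(["vision"]) == {"vision": ["language"], "language": []}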
@@ -167,6 +167,7 @@ def build_infra(self) -> MimoModelInfra:
participating_modules = [name for name, pg in pg_collections.items() if pg is not None]

# Derive module output tensor dimensionality if not explicitly configured.
# Language module produces 3D [S, B, H]; modality encoders produce 2D [S, H].
if self.module_output_ndim is not None:
output_ndim = self.module_output_ndim
else:
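A hedged sketch of the fallback branch referenced by the new comment; representing the result as a per-module dict is an assumption, not something this hunk shows:

def derive_output_ndim(participating_modules, language_key="language"):
    # Language module activations are 3-D [S, B, H]; modality encoder outputs are 2-D [S, H].
    return {name: 3 if name == language_key else 2 for name in participating_modules}

# e.g. derive_output_ndim(["vision", "language"]) == {"vision": 2, "language": 3}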
5 changes: 2 additions & 3 deletions src/megatron/bridge/training/mimo_step.py
@@ -157,9 +157,8 @@ def forward_step(
needs_data = True
if mimo_model.role is not None:
if mimo_model.role.has_language_module:
module_name = MIMO_LANGUAGE_MODULE_KEY
is_first_stage = mimo_model.role.is_first_stage(module_name)
is_last_stage = mimo_model.role.is_last_stage(module_name)
is_first_stage = mimo_model.role.is_first_stage(MIMO_LANGUAGE_MODULE_KEY)
is_last_stage = mimo_model.role.is_last_stage(MIMO_LANGUAGE_MODULE_KEY)
needs_data = is_first_stage or is_last_stage
elif mimo_model.role.has_modality_modules:
modality_modules = mimo_model.role.modality_module_names
15 changes: 7 additions & 8 deletions src/megatron/bridge/training/pretrain_mimo.py
@@ -16,14 +16,12 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional

import torch
import torch.distributed as dist
from megatron.core.models.mimo import MimoModel
from megatron.core.pipeline_parallel.multimodule_communicator import MultiModulePipelineCommunicator
from megatron.core.utils import get_model_config

from megatron.bridge.training.checkpointing import init_checkpointing_context, load_checkpoint
from megatron.bridge.training.utils.checkpoint_utils import checkpoint_exists
from megatron.bridge.training.config import ConfigContainer
from megatron.bridge.training.mimo_parallel_utils import (
build_pg_collection_for_schedule,
@@ -33,6 +31,7 @@
)
from megatron.bridge.training.state import GlobalState
from megatron.bridge.training.train_mimo import train_mimo
from megatron.bridge.training.utils.checkpoint_utils import checkpoint_exists


if TYPE_CHECKING:
@@ -256,8 +255,7 @@ def pretrain_mimo(

rampup_batch_size = getattr(cfg.train, "rampup_batch_size", None)
assert rampup_batch_size is None, (
"Microbatch rampup is not supported in MiMo training. "
"Set rampup_batch_size to None."
"Microbatch rampup is not supported in MiMo training. Set rampup_batch_size to None."
)

if nmc._GLOBAL_NUM_MICROBATCHES_CALCULATOR is None:
@@ -350,9 +348,8 @@ def pretrain_mimo(

# Broadened load-intent gating: includes non-persistent resume intent
has_persistent = cfg.checkpoint.load is not None and checkpoint_exists(cfg.checkpoint.load)
has_pretrained = (
cfg.checkpoint.pretrained_checkpoint is not None
and checkpoint_exists(cfg.checkpoint.pretrained_checkpoint)
has_pretrained = cfg.checkpoint.pretrained_checkpoint is not None and checkpoint_exists(
cfg.checkpoint.pretrained_checkpoint
)
wants_non_persistent = cfg.checkpoint.non_persistent_ckpt_type is not None
should_load = has_persistent or has_pretrained or wants_non_persistent
@@ -391,7 +388,9 @@
sig = inspect.signature(build_data_iterators_fn)
if "train_state" in sig.parameters:
train_data_iterator, valid_data_iterator = build_data_iterators_fn(
cfg, setup_output.mimo_infra, train_state=train_state,
cfg,
setup_output.mimo_infra,
train_state=train_state,
)
else:
raise RuntimeError(
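The signature check above implies the contract for custom builders: they must accept a train_state keyword or pretrain_mimo raises. A signature-only sketch (argument names inferred from the call site, bodies omitted):

def build_data_iterators(cfg, mimo_infra, train_state=None):
    # Real builders construct per-module train/valid iterators; train_state lets a
    # resumed run account for already-consumed samples.
    train_data_iterator, valid_data_iterator = None, None  # placeholders for this sketch
    return train_data_iterator, valid_data_iterator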
2 changes: 1 addition & 1 deletion src/megatron/bridge/training/train_mimo.py
@@ -29,7 +29,6 @@
from megatron.core.utils import get_model_config

from megatron.bridge.training.checkpointing import maybe_finalize_async_save
from megatron.bridge.training.train import checkpoint_and_decide_exit, save_checkpoint_and_time
from megatron.bridge.training.eval import evaluate_and_print_results
from megatron.bridge.training.mimo_parallel_utils import (
build_pg_collection_for_schedule,
@@ -46,6 +45,7 @@
should_profile_rank,
)
from megatron.bridge.training.state import GlobalState
from megatron.bridge.training.train import checkpoint_and_decide_exit
from megatron.bridge.training.utils.train_utils import (
prepare_forward_step_func,
training_log,
97 changes: 68 additions & 29 deletions tests/e2e/mimo/test_mimo_checkpoint_resume_e2e.py
@@ -51,6 +51,7 @@
from megatron.bridge.training.state import GlobalState, TrainState
from megatron.bridge.training.tokenizers.config import TokenizerConfig


logger = logging.getLogger(__name__)

SAVE_STEPS = 5
@@ -70,9 +71,15 @@

def _make_vision_config() -> TransformerConfig:
cfg = TransformerConfig(
num_layers=2, hidden_size=64, ffn_hidden_size=256, num_attention_heads=4,
use_cpu_initialization=True, pipeline_dtype=torch.bfloat16, bf16=True,
variable_seq_lengths=True, moe_token_dispatcher_type="alltoall",
num_layers=2,
hidden_size=64,
ffn_hidden_size=256,
num_attention_heads=4,
use_cpu_initialization=True,
pipeline_dtype=torch.bfloat16,
bf16=True,
variable_seq_lengths=True,
moe_token_dispatcher_type="alltoall",
)
cfg.add_bias_linear = True
cfg.add_qkv_bias = True
@@ -91,9 +98,15 @@ def _make_vision_config():

def _make_language_config() -> TransformerConfig:
return TransformerConfig(
num_layers=2, hidden_size=64, ffn_hidden_size=256, num_attention_heads=4,
use_cpu_initialization=True, pipeline_dtype=torch.bfloat16, bf16=True,
variable_seq_lengths=True, moe_token_dispatcher_type="alltoall",
num_layers=2,
hidden_size=64,
ffn_hidden_size=256,
num_attention_heads=4,
use_cpu_initialization=True,
pipeline_dtype=torch.bfloat16,
bf16=True,
variable_seq_lengths=True,
moe_token_dispatcher_type="alltoall",
cross_entropy_loss_fusion=True,
)

@@ -107,19 +120,23 @@ def _build_model_specs():
params={
"transformer_config": vision_config,
"transformer_layer_spec": get_vit_layer_with_transformer_engine_spec(),
"patch_dim": _PATCH_DIM, "img_h": _IMG_SIZE, "img_w": _IMG_SIZE,
"patch_dim": _PATCH_DIM,
"img_h": _IMG_SIZE,
"img_w": _IMG_SIZE,
},
)
vision_submodule_spec = ModuleSpec(
module=VisionModalitySubmodules, params={},
module=VisionModalitySubmodules,
params={},
submodules={"encoders": {"clip": vision_encoder}},
)
language_model_spec = ModuleSpec(
module=GPTModel,
params={
"config": language_config,
"transformer_layer_spec": get_gpt_layer_with_transformer_engine_spec(),
"vocab_size": _VOCAB_SIZE, "max_sequence_length": _SEQ_LENGTH,
"vocab_size": _VOCAB_SIZE,
"max_sequence_length": _SEQ_LENGTH,
},
)
return language_model_spec, {"vision": vision_submodule_spec}, {"vision": _SPECIAL_TOKEN_ID}
@@ -134,7 +151,7 @@ def _build_parallelism_config() -> MimoParallelismConfig:
"""
return MimoParallelismConfig(
module_parallelisms={
"llm": ModuleParallelismConfig(
"language": ModuleParallelismConfig(
tensor_model_parallel_size=int(os.environ.get("MIMO_LLM_TP", "4")),
pipeline_model_parallel_size=int(os.environ.get("MIMO_LLM_PP", "1")),
data_parallel_size=int(os.environ.get("MIMO_LLM_DP", "1")),
@@ -234,7 +251,9 @@ def _build_config(
max_dp = max(p.data_parallel_size for p in par_cfg.module_parallelisms.values())

train_cfg = TrainingConfig(
micro_batch_size=1, global_batch_size=max_dp, train_iters=train_iters,
micro_batch_size=1,
global_batch_size=max_dp,
train_iters=train_iters,
)
train_cfg.num_microbatches = 1
train_cfg.grad_reduce_in_fp32 = False
@@ -246,7 +265,7 @@
logger_cfg = LoggerConfig()
logger_cfg.log_interval = 1

llm_pp = par_cfg.module_parallelisms["llm"].pipeline_model_parallel_size
llm_pp = par_cfg.module_parallelisms["language"].pipeline_model_parallel_size
ckpt_cfg = CheckpointConfig(
save_interval=save_interval,
save=ckpt_dir,
@@ -295,7 +314,7 @@ def _run_phase_save(ckpt_dir: str) -> None:
modality_submodules_spec=modality_specs,
special_token_ids=special_tokens,
mimo_parallelism_config=_build_parallelism_config(),
topology={"vision": ["llm"], "llm": []},
topology={"vision": ["language"], "language": []},
use_cpu_initialization=True,
)
if not hasattr(mimo_provider, "num_moe_experts"):
@@ -306,22 +325,33 @@
mock_data = _build_mock_data_provider()
bridge_opt = BridgeOptimizerConfig(lr=1e-4, use_distributed_optimizer=True)
mcore_opt = MCoreOptimizerConfig(
optimizer="adam", lr=1e-4, min_lr=0.0, weight_decay=0.01,
clip_grad=1.0, bf16=True, use_distributed_optimizer=True,
optimizer="adam",
lr=1e-4,
min_lr=0.0,
weight_decay=0.01,
clip_grad=1.0,
bf16=True,
use_distributed_optimizer=True,
)

cfg = _build_config(
mimo_provider, mock_data, bridge_opt, ckpt_dir,
train_iters=SAVE_STEPS, save_interval=SAVE_STEPS,
mimo_provider,
mock_data,
bridge_opt,
ckpt_dir,
train_iters=SAVE_STEPS,
save_interval=SAVE_STEPS,
)

global_state = GlobalState()

pretrain_mimo(
cfg=cfg, mimo_provider=mimo_provider,
cfg=cfg,
mimo_provider=mimo_provider,
forward_step_func=mimo_forward_step,
build_data_iterators_fn=_build_data_iterators,
opt_config=mcore_opt, schedulers={},
opt_config=mcore_opt,
schedulers={},
global_state=global_state,
)

@@ -360,7 +390,7 @@ def _run_phase_resume(ckpt_dir: str) -> None:
modality_submodules_spec=modality_specs,
special_token_ids=special_tokens,
mimo_parallelism_config=_build_parallelism_config(),
topology={"vision": ["llm"], "llm": []},
topology={"vision": ["language"], "language": []},
use_cpu_initialization=True,
)
if not hasattr(mimo_provider, "num_moe_experts"):
@@ -371,13 +401,22 @@
mock_data = _build_mock_data_provider()
bridge_opt = BridgeOptimizerConfig(lr=1e-4, use_distributed_optimizer=True)
mcore_opt = MCoreOptimizerConfig(
optimizer="adam", lr=1e-4, min_lr=0.0, weight_decay=0.01,
clip_grad=1.0, bf16=True, use_distributed_optimizer=True,
optimizer="adam",
lr=1e-4,
min_lr=0.0,
weight_decay=0.01,
clip_grad=1.0,
bf16=True,
use_distributed_optimizer=True,
)

cfg = _build_config(
mimo_provider, mock_data, bridge_opt, ckpt_dir,
train_iters=TOTAL_STEPS, save_interval=TOTAL_STEPS,
mimo_provider,
mock_data,
bridge_opt,
ckpt_dir,
train_iters=TOTAL_STEPS,
save_interval=TOTAL_STEPS,
load_dir=ckpt_dir,
)
# Save phase used train_iters=SAVE_STEPS, so checkpoint scheduler state
@@ -390,10 +429,12 @@
global_state = GlobalState()

pretrain_mimo(
cfg=cfg, mimo_provider=mimo_provider,
cfg=cfg,
mimo_provider=mimo_provider,
forward_step_func=mimo_forward_step,
build_data_iterators_fn=_build_data_iterators,
opt_config=mcore_opt, schedulers={},
opt_config=mcore_opt,
schedulers={},
global_state=global_state,
)

Expand All @@ -402,9 +443,7 @@ def _run_phase_resume(ckpt_dir: str) -> None:
_log(f"Phase RESUME complete: step={ts.step}, consumed_train_samples={ts.consumed_train_samples}")

# Verify step continuity
assert ts.step == TOTAL_STEPS, (
f"Step continuity failed: expected {TOTAL_STEPS}, got {ts.step}"
)
assert ts.step == TOTAL_STEPS, f"Step continuity failed: expected {TOTAL_STEPS}, got {ts.step}"

# Verify consumed_train_samples did not reset to 0
assert ts.consumed_train_samples >= saved_marker["consumed_train_samples"], (
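The "llm" → "language" renames in this test keep the module key consistent between the parallelism config and the topology DAG. A condensed sketch of the pieces that must agree (the "vision" entry is assumed, since this hunk truncates it; the classes are the same ones the test already imports):

par_cfg = MimoParallelismConfig(
    module_parallelisms={
        "language": ModuleParallelismConfig(
            tensor_model_parallel_size=4,
            pipeline_model_parallel_size=1,
            data_parallel_size=1,
        ),
        # "vision": ModuleParallelismConfig(...),  # assumed; not shown in this hunk
    },
)
topology = {"vision": ["language"], "language": []}  # must use the same "language" key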
4 changes: 2 additions & 2 deletions tests/e2e/mimo/test_mimo_training_e2e.py
@@ -124,7 +124,7 @@ def _build_model_specs():
def _build_parallelism_config() -> MimoParallelismConfig:
return MimoParallelismConfig(
module_parallelisms={
"llm": ModuleParallelismConfig(
"language": ModuleParallelismConfig(
tensor_model_parallel_size=4,
pipeline_model_parallel_size=1,
data_parallel_size=1,
@@ -332,7 +332,7 @@ def main():
modality_submodules_spec=modality_submodules_spec,
special_token_ids=special_token_ids,
mimo_parallelism_config=mimo_parallelism_config,
topology={"vision": ["llm"], "llm": []},
topology={"vision": ["language"], "language": []},
use_cpu_initialization=True,
)
if not hasattr(mimo_provider, "num_moe_experts"):
4 changes: 2 additions & 2 deletions tests/e2e/mimo/test_mimo_training_llava.py
@@ -193,7 +193,7 @@ def _build_model_specs():
def _build_parallelism_config() -> MimoParallelismConfig:
return MimoParallelismConfig(
module_parallelisms={
"llm": ModuleParallelismConfig(
"language": ModuleParallelismConfig(
tensor_model_parallel_size=4,
pipeline_model_parallel_size=1,
data_parallel_size=1,
@@ -493,7 +493,7 @@ def main():
modality_submodules_spec=modality_submodules_spec,
special_token_ids=special_token_ids,
mimo_parallelism_config=mimo_parallelism_config,
topology={"images": ["llm"], "llm": []},
topology={"images": ["language"], "language": []},
use_cpu_initialization=True,
bf16=True,
)