diff --git a/src/megatron/bridge/models/mimo/mimo_builder.py b/src/megatron/bridge/models/mimo/mimo_builder.py
index a8b266a2f4..c22fbdf0b6 100644
--- a/src/megatron/bridge/models/mimo/mimo_builder.py
+++ b/src/megatron/bridge/models/mimo/mimo_builder.py
@@ -50,6 +50,7 @@ def build_hypercomm_grids(
         _ = grid.create_pg(["tp", "pp"])
         _ = grid.create_pg(["tp", "ep", "pp"])
         _ = grid.create_pg(["dp", "ep"])
+        _ = grid.create_pg(["tp", "cp", "ep", "pp", "dp"])
 
         grids[module_name] = grid
 
diff --git a/src/megatron/bridge/models/mimo/mimo_provider.py b/src/megatron/bridge/models/mimo/mimo_provider.py
index 2cddb904ec..4d8438e881 100644
--- a/src/megatron/bridge/models/mimo/mimo_provider.py
+++ b/src/megatron/bridge/models/mimo/mimo_provider.py
@@ -110,7 +110,7 @@ class MimoModelProvider(ModelProviderMixin[MimoModel]):
     mimo_parallelism_config: Optional[MimoParallelismConfig] = None
 
     # Module data-flow DAG for MultiModulePipelineCommunicator.
-    # If None, auto-derived as: all modality_submodules → language module (terminal).
+    # If None, auto-derived as: all modality_submodules → MIMO_LANGUAGE_MODULE_KEY (terminal).
     # Set explicitly for non-standard topologies (e.g., language → generator).
     topology: Optional[Dict[str, List[str]]] = None
 
@@ -167,6 +167,7 @@ def build_infra(self) -> MimoModelInfra:
         participating_modules = [name for name, pg in pg_collections.items() if pg is not None]
 
         # Derive module output tensor dimensionality if not explicitly configured.
+        # Language module produces 3D [S, B, H]; modality encoders produce 2D [S, H].
         if self.module_output_ndim is not None:
             output_ndim = self.module_output_ndim
         else:
diff --git a/src/megatron/bridge/training/mimo_step.py b/src/megatron/bridge/training/mimo_step.py
index d4a2d0ebf6..2da7fd09ec 100644
--- a/src/megatron/bridge/training/mimo_step.py
+++ b/src/megatron/bridge/training/mimo_step.py
@@ -157,9 +157,8 @@ def forward_step(
     needs_data = True
     if mimo_model.role is not None:
         if mimo_model.role.has_language_module:
-            module_name = MIMO_LANGUAGE_MODULE_KEY
-            is_first_stage = mimo_model.role.is_first_stage(module_name)
-            is_last_stage = mimo_model.role.is_last_stage(module_name)
+            is_first_stage = mimo_model.role.is_first_stage(MIMO_LANGUAGE_MODULE_KEY)
+            is_last_stage = mimo_model.role.is_last_stage(MIMO_LANGUAGE_MODULE_KEY)
             needs_data = is_first_stage or is_last_stage
         elif mimo_model.role.has_modality_modules:
             modality_modules = mimo_model.role.modality_module_names
diff --git a/src/megatron/bridge/training/pretrain_mimo.py b/src/megatron/bridge/training/pretrain_mimo.py
index 66434535b1..e659260239 100644
--- a/src/megatron/bridge/training/pretrain_mimo.py
+++ b/src/megatron/bridge/training/pretrain_mimo.py
@@ -16,14 +16,12 @@
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional
 
-import torch
 import torch.distributed as dist
 from megatron.core.models.mimo import MimoModel
 from megatron.core.pipeline_parallel.multimodule_communicator import MultiModulePipelineCommunicator
 from megatron.core.utils import get_model_config
 
 from megatron.bridge.training.checkpointing import init_checkpointing_context, load_checkpoint
-from megatron.bridge.training.utils.checkpoint_utils import checkpoint_exists
 from megatron.bridge.training.config import ConfigContainer
 from megatron.bridge.training.mimo_parallel_utils import (
     build_pg_collection_for_schedule,
@@ -33,6 +31,7 @@
 )
 from megatron.bridge.training.state import GlobalState
 from megatron.bridge.training.train_mimo import train_mimo
+from megatron.bridge.training.utils.checkpoint_utils import checkpoint_exists
 
 
 if TYPE_CHECKING:
@@ -256,8 +255,7 @@ def pretrain_mimo(
 
     rampup_batch_size = getattr(cfg.train, "rampup_batch_size", None)
     assert rampup_batch_size is None, (
-        "Microbatch rampup is not supported in MiMo training. "
-        "Set rampup_batch_size to None."
+        "Microbatch rampup is not supported in MiMo training. Set rampup_batch_size to None."
     )
 
     if nmc._GLOBAL_NUM_MICROBATCHES_CALCULATOR is None:
@@ -350,9 +348,8 @@ def pretrain_mimo(
 
     # Broadened load-intent gating: includes non-persistent resume intent
     has_persistent = cfg.checkpoint.load is not None and checkpoint_exists(cfg.checkpoint.load)
-    has_pretrained = (
-        cfg.checkpoint.pretrained_checkpoint is not None
-        and checkpoint_exists(cfg.checkpoint.pretrained_checkpoint)
+    has_pretrained = cfg.checkpoint.pretrained_checkpoint is not None and checkpoint_exists(
+        cfg.checkpoint.pretrained_checkpoint
     )
     wants_non_persistent = cfg.checkpoint.non_persistent_ckpt_type is not None
     should_load = has_persistent or has_pretrained or wants_non_persistent
@@ -391,7 +388,9 @@ def pretrain_mimo(
         sig = inspect.signature(build_data_iterators_fn)
         if "train_state" in sig.parameters:
             train_data_iterator, valid_data_iterator = build_data_iterators_fn(
-                cfg, setup_output.mimo_infra, train_state=train_state,
+                cfg,
+                setup_output.mimo_infra,
+                train_state=train_state,
             )
         else:
             raise RuntimeError(
diff --git a/src/megatron/bridge/training/train_mimo.py b/src/megatron/bridge/training/train_mimo.py
index 2dfc6ebb3b..aa020d065d 100644
--- a/src/megatron/bridge/training/train_mimo.py
+++ b/src/megatron/bridge/training/train_mimo.py
@@ -29,7 +29,6 @@
 from megatron.core.utils import get_model_config
 
 from megatron.bridge.training.checkpointing import maybe_finalize_async_save
-from megatron.bridge.training.train import checkpoint_and_decide_exit, save_checkpoint_and_time
 from megatron.bridge.training.eval import evaluate_and_print_results
 from megatron.bridge.training.mimo_parallel_utils import (
     build_pg_collection_for_schedule,
@@ -46,6 +45,7 @@
     should_profile_rank,
 )
 from megatron.bridge.training.state import GlobalState
+from megatron.bridge.training.train import checkpoint_and_decide_exit
 from megatron.bridge.training.utils.train_utils import (
     prepare_forward_step_func,
     training_log,
diff --git a/tests/e2e/mimo/test_mimo_checkpoint_resume_e2e.py b/tests/e2e/mimo/test_mimo_checkpoint_resume_e2e.py
index 7d4e57b97d..3e093e52ca 100644
--- a/tests/e2e/mimo/test_mimo_checkpoint_resume_e2e.py
+++ b/tests/e2e/mimo/test_mimo_checkpoint_resume_e2e.py
@@ -51,6 +51,7 @@
 from megatron.bridge.training.state import GlobalState, TrainState
 from megatron.bridge.training.tokenizers.config import TokenizerConfig
 
+
 logger = logging.getLogger(__name__)
 
 SAVE_STEPS = 5
@@ -70,9 +71,15 @@ def _make_vision_config() -> TransformerConfig:
     cfg = TransformerConfig(
-        num_layers=2, hidden_size=64, ffn_hidden_size=256, num_attention_heads=4,
-        use_cpu_initialization=True, pipeline_dtype=torch.bfloat16, bf16=True,
-        variable_seq_lengths=True, moe_token_dispatcher_type="alltoall",
+        num_layers=2,
+        hidden_size=64,
+        ffn_hidden_size=256,
+        num_attention_heads=4,
+        use_cpu_initialization=True,
+        pipeline_dtype=torch.bfloat16,
+        bf16=True,
+        variable_seq_lengths=True,
+        moe_token_dispatcher_type="alltoall",
     )
     cfg.add_bias_linear = True
     cfg.add_qkv_bias = True
@@ -91,9 +98,15 @@ def _make_vision_config() -> TransformerConfig:
 
 
 def _make_language_config() -> TransformerConfig:
     return TransformerConfig(
-        num_layers=2, hidden_size=64, ffn_hidden_size=256, num_attention_heads=4,
-        use_cpu_initialization=True, pipeline_dtype=torch.bfloat16, bf16=True,
-        variable_seq_lengths=True, moe_token_dispatcher_type="alltoall",
+        num_layers=2,
+        hidden_size=64,
+        ffn_hidden_size=256,
+        num_attention_heads=4,
+        use_cpu_initialization=True,
+        pipeline_dtype=torch.bfloat16,
+        bf16=True,
+        variable_seq_lengths=True,
+        moe_token_dispatcher_type="alltoall",
         cross_entropy_loss_fusion=True,
     )
@@ -107,11 +120,14 @@ def _build_model_specs():
         params={
             "transformer_config": vision_config,
             "transformer_layer_spec": get_vit_layer_with_transformer_engine_spec(),
-            "patch_dim": _PATCH_DIM, "img_h": _IMG_SIZE, "img_w": _IMG_SIZE,
+            "patch_dim": _PATCH_DIM,
+            "img_h": _IMG_SIZE,
+            "img_w": _IMG_SIZE,
         },
     )
     vision_submodule_spec = ModuleSpec(
-        module=VisionModalitySubmodules, params={},
+        module=VisionModalitySubmodules,
+        params={},
         submodules={"encoders": {"clip": vision_encoder}},
     )
     language_model_spec = ModuleSpec(
@@ -119,7 +135,8 @@ def _build_model_specs():
         params={
             "config": language_config,
             "transformer_layer_spec": get_gpt_layer_with_transformer_engine_spec(),
-            "vocab_size": _VOCAB_SIZE, "max_sequence_length": _SEQ_LENGTH,
+            "vocab_size": _VOCAB_SIZE,
+            "max_sequence_length": _SEQ_LENGTH,
         },
     )
     return language_model_spec, {"vision": vision_submodule_spec}, {"vision": _SPECIAL_TOKEN_ID}
@@ -134,7 +151,7 @@ def _build_parallelism_config() -> MimoParallelismConfig:
     """
     return MimoParallelismConfig(
         module_parallelisms={
-            "llm": ModuleParallelismConfig(
+            "language": ModuleParallelismConfig(
                 tensor_model_parallel_size=int(os.environ.get("MIMO_LLM_TP", "4")),
                 pipeline_model_parallel_size=int(os.environ.get("MIMO_LLM_PP", "1")),
                 data_parallel_size=int(os.environ.get("MIMO_LLM_DP", "1")),
@@ -234,7 +251,9 @@ def _build_config(
     max_dp = max(p.data_parallel_size for p in par_cfg.module_parallelisms.values())
 
     train_cfg = TrainingConfig(
-        micro_batch_size=1, global_batch_size=max_dp, train_iters=train_iters,
+        micro_batch_size=1,
+        global_batch_size=max_dp,
+        train_iters=train_iters,
     )
     train_cfg.num_microbatches = 1
     train_cfg.grad_reduce_in_fp32 = False
@@ -246,7 +265,7 @@ def _build_config(
     logger_cfg = LoggerConfig()
     logger_cfg.log_interval = 1
 
-    llm_pp = par_cfg.module_parallelisms["llm"].pipeline_model_parallel_size
+    llm_pp = par_cfg.module_parallelisms["language"].pipeline_model_parallel_size
     ckpt_cfg = CheckpointConfig(
         save_interval=save_interval,
         save=ckpt_dir,
@@ -295,7 +314,7 @@ def _run_phase_save(ckpt_dir: str) -> None:
         modality_submodules_spec=modality_specs,
         special_token_ids=special_tokens,
         mimo_parallelism_config=_build_parallelism_config(),
-        topology={"vision": ["llm"], "llm": []},
+        topology={"vision": ["language"], "language": []},
         use_cpu_initialization=True,
     )
     if not hasattr(mimo_provider, "num_moe_experts"):
@@ -306,22 +325,33 @@ def _run_phase_save(ckpt_dir: str) -> None:
 
     mock_data = _build_mock_data_provider()
     bridge_opt = BridgeOptimizerConfig(lr=1e-4, use_distributed_optimizer=True)
     mcore_opt = MCoreOptimizerConfig(
-        optimizer="adam", lr=1e-4, min_lr=0.0, weight_decay=0.01,
-        clip_grad=1.0, bf16=True, use_distributed_optimizer=True,
+        optimizer="adam",
+        lr=1e-4,
+        min_lr=0.0,
+        weight_decay=0.01,
+        clip_grad=1.0,
+        bf16=True,
+        use_distributed_optimizer=True,
     )
 
     cfg = _build_config(
-        mimo_provider, mock_data, bridge_opt, ckpt_dir,
-        train_iters=SAVE_STEPS, save_interval=SAVE_STEPS,
+        mimo_provider,
+        mock_data,
+        bridge_opt,
+        ckpt_dir,
+        train_iters=SAVE_STEPS,
+        save_interval=SAVE_STEPS,
     )
 
     global_state = GlobalState()
     pretrain_mimo(
-        cfg=cfg, mimo_provider=mimo_provider,
+        cfg=cfg,
+        mimo_provider=mimo_provider,
         forward_step_func=mimo_forward_step,
         build_data_iterators_fn=_build_data_iterators,
-        opt_config=mcore_opt, schedulers={},
+        opt_config=mcore_opt,
+        schedulers={},
         global_state=global_state,
     )
 
@@ -360,7 +390,7 @@ def _run_phase_resume(ckpt_dir: str) -> None:
         modality_submodules_spec=modality_specs,
         special_token_ids=special_tokens,
         mimo_parallelism_config=_build_parallelism_config(),
-        topology={"vision": ["llm"], "llm": []},
+        topology={"vision": ["language"], "language": []},
         use_cpu_initialization=True,
     )
     if not hasattr(mimo_provider, "num_moe_experts"):
@@ -371,13 +401,22 @@ def _run_phase_resume(ckpt_dir: str) -> None:
     mock_data = _build_mock_data_provider()
     bridge_opt = BridgeOptimizerConfig(lr=1e-4, use_distributed_optimizer=True)
     mcore_opt = MCoreOptimizerConfig(
-        optimizer="adam", lr=1e-4, min_lr=0.0, weight_decay=0.01,
-        clip_grad=1.0, bf16=True, use_distributed_optimizer=True,
+        optimizer="adam",
+        lr=1e-4,
+        min_lr=0.0,
+        weight_decay=0.01,
+        clip_grad=1.0,
+        bf16=True,
+        use_distributed_optimizer=True,
     )
 
     cfg = _build_config(
-        mimo_provider, mock_data, bridge_opt, ckpt_dir,
-        train_iters=TOTAL_STEPS, save_interval=TOTAL_STEPS,
+        mimo_provider,
+        mock_data,
+        bridge_opt,
+        ckpt_dir,
+        train_iters=TOTAL_STEPS,
+        save_interval=TOTAL_STEPS,
         load_dir=ckpt_dir,
     )
     # Save phase used train_iters=SAVE_STEPS, so checkpoint scheduler state
@@ -390,10 +429,12 @@ def _run_phase_resume(ckpt_dir: str) -> None:
 
     global_state = GlobalState()
     pretrain_mimo(
-        cfg=cfg, mimo_provider=mimo_provider,
+        cfg=cfg,
+        mimo_provider=mimo_provider,
         forward_step_func=mimo_forward_step,
         build_data_iterators_fn=_build_data_iterators,
-        opt_config=mcore_opt, schedulers={},
+        opt_config=mcore_opt,
+        schedulers={},
         global_state=global_state,
     )
 
@@ -402,9 +443,7 @@ def _run_phase_resume(ckpt_dir: str) -> None:
     _log(f"Phase RESUME complete: step={ts.step}, consumed_train_samples={ts.consumed_train_samples}")
 
     # Verify step continuity
-    assert ts.step == TOTAL_STEPS, (
-        f"Step continuity failed: expected {TOTAL_STEPS}, got {ts.step}"
-    )
+    assert ts.step == TOTAL_STEPS, f"Step continuity failed: expected {TOTAL_STEPS}, got {ts.step}"
 
     # Verify consumed_train_samples did not reset to 0
     assert ts.consumed_train_samples >= saved_marker["consumed_train_samples"], (
diff --git a/tests/e2e/mimo/test_mimo_training_e2e.py b/tests/e2e/mimo/test_mimo_training_e2e.py
index d9e38dc7bb..7564794a19 100644
--- a/tests/e2e/mimo/test_mimo_training_e2e.py
+++ b/tests/e2e/mimo/test_mimo_training_e2e.py
@@ -124,7 +124,7 @@ def _build_model_specs():
 def _build_parallelism_config() -> MimoParallelismConfig:
     return MimoParallelismConfig(
         module_parallelisms={
-            "llm": ModuleParallelismConfig(
+            "language": ModuleParallelismConfig(
                 tensor_model_parallel_size=4,
                 pipeline_model_parallel_size=1,
                 data_parallel_size=1,
@@ -332,7 +332,7 @@ def main():
         modality_submodules_spec=modality_submodules_spec,
         special_token_ids=special_token_ids,
         mimo_parallelism_config=mimo_parallelism_config,
-        topology={"vision": ["llm"], "llm": []},
+        topology={"vision": ["language"], "language": []},
         use_cpu_initialization=True,
     )
     if not hasattr(mimo_provider, "num_moe_experts"):
diff --git a/tests/e2e/mimo/test_mimo_training_llava.py b/tests/e2e/mimo/test_mimo_training_llava.py
index 97b0006faa..5d82ea1323 100644
--- a/tests/e2e/mimo/test_mimo_training_llava.py
+++ b/tests/e2e/mimo/test_mimo_training_llava.py
@@ -193,7 +193,7 @@ def _build_model_specs():
 def _build_parallelism_config() -> MimoParallelismConfig:
     return MimoParallelismConfig(
         module_parallelisms={
-            "llm": ModuleParallelismConfig(
+            "language": ModuleParallelismConfig(
                 tensor_model_parallel_size=4,
                 pipeline_model_parallel_size=1,
                 data_parallel_size=1,
@@ -493,7 +493,7 @@ def main():
         modality_submodules_spec=modality_submodules_spec,
         special_token_ids=special_token_ids,
         mimo_parallelism_config=mimo_parallelism_config,
-        topology={"images": ["llm"], "llm": []},
+        topology={"images": ["language"], "language": []},
         use_cpu_initialization=True,
         bf16=True,
     )
diff --git a/tests/unit_tests/training/mimo/test_mimo_checkpointing.py b/tests/unit_tests/training/mimo/test_mimo_checkpointing.py
index ddff0bcc1e..f7381c2f49 100644
--- a/tests/unit_tests/training/mimo/test_mimo_checkpointing.py
+++ b/tests/unit_tests/training/mimo/test_mimo_checkpointing.py
@@ -37,7 +37,7 @@ def _make_mimo_infra(*, num_active_pgs: int = 1) -> Mock:
     for i in range(num_active_pgs):
         pgs[f"module_{i}"] = Mock()
     infra.pg_collections = pgs
-    infra.module_to_grid_map = {"llm": Mock()}
+    infra.module_to_grid_map = {"language": Mock()}
     infra.topology = Mock()
     return infra
 
@@ -251,9 +251,9 @@ def test_setup_mimo_initializes_checkpointing_context(
 
         provider = Mock()
         infra = Mock()
-        infra.module_to_grid_map = {"llm": Mock()}
+        infra.module_to_grid_map = {"language": Mock()}
         infra.topology = Mock()
-        infra.pg_collections = {"llm": Mock()}
+        infra.pg_collections = {"language": Mock()}
         provider.build_infra.return_value = infra
         provider.provide_distributed_model.return_value = [Mock()]
 
@@ -380,8 +380,8 @@ def test_calls_checkpoint_and_decide_exit_with_pg_collection(
 
         pg = Mock()
         infra = Mock()
-        infra.pg_collections = {"llm": pg}
-        infra.module_to_grid_map = {"llm": Mock()}
+        infra.pg_collections = {"language": pg}
+        infra.module_to_grid_map = {"language": Mock()}
         infra.topology = Mock()
         mock_build_pg.return_value = Mock(spec=[])  # not a list
 
@@ -394,7 +394,7 @@ def test_calls_checkpoint_and_decide_exit_with_pg_collection(
             forward_step_func=Mock(),
             model=Mock(),
             optimizer=Mock(),
-            schedulers={"llm": _make_scheduler_mock()},
+            schedulers={"language": _make_scheduler_mock()},
             train_data_iterator=train_iter,
             valid_data_iterator=None,
             global_state=state,
@@ -441,8 +441,8 @@ def test_exits_loop_when_checkpoint_and_decide_exit_returns_true(
         mock_get_config.return_value = mock_config
 
         infra = Mock()
-        infra.pg_collections = {"llm": Mock()}
-        infra.module_to_grid_map = {"llm": Mock()}
+        infra.pg_collections = {"language": Mock()}
+        infra.module_to_grid_map = {"language": Mock()}
         infra.topology = Mock()
         mock_build_pg.return_value = Mock(spec=[])
 
@@ -452,7 +452,7 @@ def test_exits_loop_when_checkpoint_and_decide_exit_returns_true(
             forward_step_func=Mock(),
             model=Mock(),
             optimizer=Mock(),
-            schedulers={"llm": _make_scheduler_mock()},
+            schedulers={"language": _make_scheduler_mock()},
             train_data_iterator=Mock(),
             valid_data_iterator=None,
             global_state=state,
@@ -496,8 +496,8 @@ def test_async_finalize_called_at_top_of_loop(
         mock_get_config.return_value = mock_config
 
         infra = Mock()
-        infra.pg_collections = {"llm": Mock()}
-        infra.module_to_grid_map = {"llm": Mock()}
+        infra.pg_collections = {"language": Mock()}
+        infra.module_to_grid_map = {"language": Mock()}
         infra.topology = Mock()
         mock_build_pg.return_value = Mock(spec=[])
 
@@ -507,7 +507,7 @@ def test_async_finalize_called_at_top_of_loop(
             forward_step_func=Mock(),
             model=Mock(),
             optimizer=Mock(),
-            schedulers={"llm": _make_scheduler_mock()},
+            schedulers={"language": _make_scheduler_mock()},
             train_data_iterator=Mock(),
             valid_data_iterator=None,
             global_state=state,
@@ -519,12 +519,8 @@ def test_async_finalize_called_at_top_of_loop(
 
         # 2 non-blocking calls (top of each iteration) + 1 blocking call (shutdown)
         assert mock_async_finalize.call_count == 3
-        non_blocking_calls = [
-            c for c in mock_async_finalize.call_args_list if c.kwargs.get("blocking") is False
-        ]
-        blocking_calls = [
-            c for c in mock_async_finalize.call_args_list if c.kwargs.get("blocking") is True
-        ]
+        non_blocking_calls = [c for c in mock_async_finalize.call_args_list if c.kwargs.get("blocking") is False]
+        blocking_calls = [c for c in mock_async_finalize.call_args_list if c.kwargs.get("blocking") is True]
         assert len(non_blocking_calls) == 2
         assert len(blocking_calls) == 1
         assert blocking_calls[0].kwargs.get("terminate") is True
@@ -562,8 +558,8 @@ def test_no_inline_save_checkpoint_call(
         mock_get_config.return_value = mock_config
 
         infra = Mock()
-        infra.pg_collections = {"llm": Mock()}
-        infra.module_to_grid_map = {"llm": Mock()}
+        infra.pg_collections = {"language": Mock()}
+        infra.module_to_grid_map = {"language": Mock()}
         infra.topology = Mock()
         mock_build_pg.return_value = Mock(spec=[])
 
@@ -574,7 +570,7 @@ def test_no_inline_save_checkpoint_call(
             forward_step_func=Mock(),
             model=Mock(),
             optimizer=Mock(),
-            schedulers={"llm": _make_scheduler_mock()},
+            schedulers={"language": _make_scheduler_mock()},
             train_data_iterator=Mock(),
             valid_data_iterator=None,
             global_state=state,
@@ -605,7 +601,7 @@ def _make_setup_output_for_load(
 ) -> SimpleNamespace:
     """Create a MimoSetupOutput-like namespace suitable for pretrain_mimo load tests."""
     if pg_collections is None:
-        pg_collections = {"llm": Mock()}
+        pg_collections = {"language": Mock()}
 
     train_state = SimpleNamespace(
         step=train_state_step,
@@ -623,7 +619,7 @@ def _make_setup_output_for_load(
     return SimpleNamespace(
         model=MagicMock(),
         mimo_infra=SimpleNamespace(
-            module_to_grid_map={"llm": Mock()},
+            module_to_grid_map={"language": Mock()},
             pg_collections=pg_collections,
             topology=Mock(),
         ),
@@ -715,7 +711,7 @@ def _run_pretrain_mimo(
         m_dist.get_rank.return_value = 0
         m_dist.get_world_size.return_value = 2
         m_unwrap.return_value = MagicMock(
-            mimo_config=SimpleNamespace(module_to_grid_map={"llm": Mock()}, language_module_key="llm"),
+            mimo_config=SimpleNamespace(module_to_grid_map={"language": Mock()}),
        )
         mock_optimizer = MagicMock()
         mock_optimizer.module_infos = {}
@@ -775,7 +771,9 @@ def test_load_forwards_list_wrapped_model(self):
         cfg = _make_pretrain_cfg(load_path="/tmp/ckpt")
         setup_output = _make_setup_output_for_load()
         mocks = _run_pretrain_mimo(
-            cfg=cfg, setup_output=setup_output, checkpoint_exists_return=True,
+            cfg=cfg,
+            setup_output=setup_output,
+            checkpoint_exists_return=True,
         )
         _, kwargs = mocks["load_checkpoint"].call_args
         assert isinstance(kwargs["model"], list)
@@ -784,10 +782,12 @@ def test_load_forwards_list_wrapped_model(self):
 
     def test_load_forwards_explicit_pg_collection(self):
         pg = Mock()
-        setup_output = _make_setup_output_for_load(pg_collections={"llm": pg})
+        setup_output = _make_setup_output_for_load(pg_collections={"language": pg})
         cfg = _make_pretrain_cfg(load_path="/tmp/ckpt")
         mocks = _run_pretrain_mimo(
-            cfg=cfg, setup_output=setup_output, checkpoint_exists_return=True,
+            cfg=cfg,
+            setup_output=setup_output,
+            checkpoint_exists_return=True,
         )
         _, kwargs = mocks["load_checkpoint"].call_args
         assert kwargs["pg_collection"] is pg
@@ -796,7 +796,9 @@ def test_load_forwards_checkpointing_context(self):
         setup_output = _make_setup_output_for_load()
         cfg = _make_pretrain_cfg(load_path="/tmp/ckpt")
         mocks = _run_pretrain_mimo(
-            cfg=cfg, setup_output=setup_output, checkpoint_exists_return=True,
+            cfg=cfg,
+            setup_output=setup_output,
+            checkpoint_exists_return=True,
         )
         _, kwargs = mocks["load_checkpoint"].call_args
         assert kwargs["checkpointing_context"] is setup_output.checkpointing_context
@@ -804,10 +806,12 @@ def test_load_forwards_first_scheduler(self):
         sched_a = _make_scheduler_mock()
         sched_b = _make_scheduler_mock()
-        schedulers = {"llm": sched_a, "vision": sched_b}
+        schedulers = {"language": sched_a, "vision": sched_b}
         cfg = _make_pretrain_cfg(load_path="/tmp/ckpt")
         mocks = _run_pretrain_mimo(
-            cfg=cfg, schedulers=schedulers, checkpoint_exists_return=True,
+            cfg=cfg,
+            schedulers=schedulers,
+            checkpoint_exists_return=True,
         )
         _, kwargs = mocks["load_checkpoint"].call_args
         assert kwargs["opt_param_scheduler"] is sched_a
 
@@ -826,17 +830,21 @@ def test_rejects_zero_active_pgs_in_pretrain(self):
         cfg = _make_pretrain_cfg(load_path="/tmp/ckpt")
         with pytest.raises(AssertionError, match="exactly one active ProcessGroupCollection"):
             _run_pretrain_mimo(
-                cfg=cfg, setup_output=setup_output, checkpoint_exists_return=True,
+                cfg=cfg,
+                setup_output=setup_output,
+                checkpoint_exists_return=True,
             )
 
     def test_rejects_multiple_active_pgs_in_pretrain(self):
         setup_output = _make_setup_output_for_load(
-            pg_collections={"llm": Mock(), "vision": Mock()},
+            pg_collections={"language": Mock(), "vision": Mock()},
         )
         cfg = _make_pretrain_cfg(load_path="/tmp/ckpt")
         with pytest.raises(AssertionError, match="exactly one active ProcessGroupCollection"):
             _run_pretrain_mimo(
-                cfg=cfg, setup_output=setup_output, checkpoint_exists_return=True,
+                cfg=cfg,
+                setup_output=setup_output,
+                checkpoint_exists_return=True,
             )
 
 
@@ -857,7 +865,7 @@ def test_scheduler_fanout_after_load(self):
         sched_b = MagicMock()
         sched_b.optimizer.param_groups = [{"lr": 1e-4}]
 
-        schedulers = {"llm": sched_a, "vision": sched_b}
+        schedulers = {"language": sched_a, "vision": sched_b}
         cfg = _make_pretrain_cfg(load_path="/tmp/ckpt")
 
         # Simulate load succeeding and setting step > 0 via side_effect.
@@ -876,7 +884,7 @@ def test_no_fanout_when_single_scheduler(self):
         sched.optimizer.param_groups = [{"lr": 1e-4}]
         sched.state_dict.return_value = {"step": 50}
 
-        schedulers = {"llm": sched}
+        schedulers = {"language": sched}
         cfg = _make_pretrain_cfg(load_path="/tmp/ckpt")
 
         _run_pretrain_mimo(cfg=cfg, schedulers=schedulers, checkpoint_exists_return=True)
@@ -911,7 +919,6 @@ def test_iterator_builder_receives_train_state_mock(self):
         """When resuming (step > 0), builder receives train_state kwarg."""
         build_fn = MagicMock(return_value=(iter([]), None))
         # Give mock the train_state parameter so inspect.signature finds it
-        import types
 
         def _sig_fn(cfg, mimo_infra, *, train_state=None):
             pass
@@ -970,7 +977,7 @@ def _make_mimo_optimizer(self):
         opt_b = MagicMock()
 
         module_infos = {
-            "llm": ModuleOptimizerInfo(optimizer=opt_a, grid=Mock(), pg_collection=Mock(), is_active=True),
+            "language": ModuleOptimizerInfo(optimizer=opt_a, grid=Mock(), pg_collection=Mock(), is_active=True),
             "vision": ModuleOptimizerInfo(optimizer=opt_b, grid=Mock(), pg_collection=Mock(), is_active=True),
         }
         config = MagicMock()
@@ -978,14 +985,14 @@ def _make_mimo_optimizer(self):
 
     def test_load_state_dict_dispatches_per_module(self):
         mimo_opt, opt_a, opt_b = self._make_mimo_optimizer()
-        state = {"llm": {"param": 1}, "vision": {"param": 2}}
+        state = {"language": {"param": 1}, "vision": {"param": 2}}
         mimo_opt.load_state_dict(state)
         opt_a.load_state_dict.assert_called_once_with({"param": 1})
         opt_b.load_state_dict.assert_called_once_with({"param": 2})
 
     def test_load_state_dict_skips_missing_keys(self):
         mimo_opt, opt_a, opt_b = self._make_mimo_optimizer()
-        state = {"llm": {"param": 1}}
+        state = {"language": {"param": 1}}
         mimo_opt.load_state_dict(state)
         opt_a.load_state_dict.assert_called_once()
         opt_b.load_state_dict.assert_not_called()
@@ -996,9 +1003,9 @@ def test_sharded_state_dict_generates_per_module(self):
         opt_b.sharded_state_dict.return_value = {"b": "sharded_b"}
 
         result = mimo_opt.sharded_state_dict({}, is_loading=True)
-        assert "llm" in result
+        assert "language" in result
         assert "vision" in result
-        assert result["llm"] == {"a": "sharded_a"}
+        assert result["language"] == {"a": "sharded_a"}
         assert result["vision"] == {"b": "sharded_b"}
 
     def test_reload_model_params_delegates_to_all_active(self):
@@ -1011,7 +1018,7 @@ def test_is_stub_optimizer_when_no_active(self):
         from megatron.core.models.mimo.optimizer import MimoOptimizer, ModuleOptimizerInfo
 
         module_infos = {
-            "llm": ModuleOptimizerInfo(optimizer=None, grid=Mock(), pg_collection=Mock(), is_active=False),
+            "language": ModuleOptimizerInfo(optimizer=None, grid=Mock(), pg_collection=Mock(), is_active=False),
         }
         mimo_opt = MimoOptimizer(module_infos, MagicMock())
         assert mimo_opt.is_stub_optimizer is True
@@ -1052,7 +1059,9 @@ def test_floating_point_ops_preserved(self):
         setup_output = _make_setup_output_for_load(floating_point_operations_so_far=99999)
         cfg = _make_pretrain_cfg(load_path="/tmp/ckpt")
         mocks = _run_pretrain_mimo(
-            cfg=cfg, setup_output=setup_output, checkpoint_exists_return=True,
+            cfg=cfg,
+            setup_output=setup_output,
+            checkpoint_exists_return=True,
         )
         _, kwargs = mocks["train_mimo"].call_args
         assert kwargs["global_state"].train_state.floating_point_operations_so_far == 99999
diff --git a/tests/unit_tests/training/mimo/test_pretrain_mimo.py b/tests/unit_tests/training/mimo/test_pretrain_mimo.py
index db5730d544..a9fb75c03c 100644
--- a/tests/unit_tests/training/mimo/test_pretrain_mimo.py
+++ b/tests/unit_tests/training/mimo/test_pretrain_mimo.py
@@ -29,7 +29,7 @@ def _make_setup_output(module_to_grid_map):
         model=MagicMock(),
         mimo_infra=SimpleNamespace(
             module_to_grid_map=module_to_grid_map,
-            pg_collections={"llm": MagicMock()},
+            pg_collections={"language": MagicMock()},
         ),
         multimodule_communicator=MagicMock(),
         train_data_iterator=iter([]),