feat(mimo): Phase 4 - MiMo training, model/provider, data loading, heterogeneous parallelism #2869
aroshanghias-nvd wants to merge 6 commits into main
Conversation
Squash of all Phase 4 MiMo work from mimo/phase4-training (47870e4), rebased onto upstream/main at f1fb06a. Includes:

- MimoModelProvider with ModuleSpec-based API and heterogeneous LLaVA support
- MiMo training loop (pretrain_mimo, train_mimo, mimo_step)
- Heterogeneous TP/PP/DP parallelism plumbing (mimo_parallel_utils)
- MiMo data loading (collate, dataset, loaders, hf_provider, mock_provider)
- Data loader dispatch routing for MIMO models (loaders.py)
- MiMo DDP wrapping and model builder
- Kamran's loss mask and heterogeneous LLaVA dataset support
- Megatron-LM submodule pinned to PR #3212 head
- Full unit test coverage (provider, config, step, collate, pretrain tests)

Phase 5 (checkpointing/evaluation) is stacked in a separate branch. Original commit history preserved in backup/mimo-phase4-training-v0 (47870e4).

Signed-off-by: Ali Roshan Ghias <aroshanghias@nvidia.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Li Ding <liding@nvidia.com>
/ok to test 5e0b771
📝 Walkthrough

This pull request introduces comprehensive MIMO heterogeneous parallel training infrastructure, including refactored data loading with MIMO-specific paths, new abstract dataset provider interfaces, loss mask tensor handling, updated distributed grid and process group management, and new training entrypoints with multi-module gradient synchronization support.
Sequence Diagrams

```mermaid
sequenceDiagram
    participant Client as Training Client
    participant DataLoader as MIMO DataLoader
    participant Provider as MimoDatasetProvider
    participant Dataset as MimoDataset
    participant Collate as mimo_collate_fn
    participant Trainer as Trainer
    Client->>DataLoader: build_mimo_data_loaders()
    DataLoader->>Provider: build_datasets(context)
    Provider-->>DataLoader: (train_ds, valid_ds, test_ds)
    DataLoader->>Collate: (batches from DataLoader)
    Collate->>Dataset: Stack loss_mask, input_ids, labels, attention_mask
    Collate-->>DataLoader: Collated batch dict
    DataLoader-->>Client: (train_loader, valid_loader, test_loader)
    Client->>Trainer: train_mimo(model, optimizer, iterators)
    Trainer->>Trainer: train_step_mimo()
    Trainer->>Trainer: multimodule_no_sync() context
    Trainer->>Trainer: forward_backward_pipelining()
    Trainer->>Trainer: finalize_model_grads_multimodule()
    Trainer->>Trainer: optimizer.step()
    Trainer-->>Client: loss_dict, grad_norm
```

```mermaid
sequenceDiagram
    participant Rank as Distributed Rank
    participant MIMO as MimoModel
    participant ForwardStep as forward_step()
    participant LossFn as loss_func()
    participant DataIter as DataIterator
    participant Alloc as Memory/Alloc
    Rank->>ForwardStep: (state, data_iter, model)
    ForwardStep->>MIMO: unwrap_mimo_model()
    MIMO-->>ForwardStep: model ref
    ForwardStep->>DataIter: Check if rank needs data
    alt Rank needs data
        ForwardStep->>DataIter: get_batch()
        DataIter-->>ForwardStep: {loss_mask, input_ids, labels, ...}
    else Rank doesn't need data
        ForwardStep->>Alloc: Create null placeholders
        Alloc-->>ForwardStep: {input_ids: None, ...}
    end
    ForwardStep->>MIMO: model(**batch)
    MIMO-->>ForwardStep: output_tensor
    alt Last PP stage
        ForwardStep->>LossFn: Return partial(loss_func, loss_mask)
    else Intermediate stage
        ForwardStep-->>Rank: (output_tensor, None)
    end
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes
🚥 Pre-merge checks: ✅ 3 passed | ❌ 1 inconclusive
Actionable comments posted: 8
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tests/unit_tests/models/mimo/test_mimo_builder.py (1)
Lines 122-125: ⚠️ Potential issue | 🟡 Minor — Stale comment: "llm" should be "language".
The comment still references "llm" but the module key was changed to "language".
📝 Suggested fix
```diff
- # First call (llm): shape [8, 1, 1, 2, 1]
+ # First call (language): shape [8, 1, 1, 2, 1]
  first_call_kwargs = mock_grid_class.call_args_list[0][1]
  assert first_call_kwargs["shape"] == [8, 1, 1, 2, 1]
  assert first_call_kwargs["rank_offset"] == 0
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unit_tests/models/mimo/test_mimo_builder.py` around lines 122 - 125, Update the stale inline comment that reads "# First call (llm): shape [8, 1, 1, 2, 1]" to reference the new module key name "language" instead of "llm"; locate the comment near the assertions using first_call_kwargs and mock_grid_class.call_args_list in test_mimo_builder.py and change it to something like "# First call (language): shape [8, 1, 1, 2, 1]" so the comment matches the updated module key.
🧹 Nitpick comments (8)
tests/unit_tests/data/mimo/test_dp_utils.py (1)
Lines 78-81: Consider underscore prefix for unused unpacked variables. In tests that don't assert on `dp_rank` and `dp_size`, consider using underscores to silence Ruff warnings and signal intent.

♻️ Suggested fix (optional)

```diff
- dp_rank, dp_size, needs_data, loader_module = get_mimo_dp_info(mimo_cfg, grids)
+ _dp_rank, _dp_size, needs_data, loader_module = get_mimo_dp_info(mimo_cfg, grids)
```

Apply similar changes at lines 94, 110, and 126 where `dp_rank`/`dp_size` are not asserted.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unit_tests/data/mimo/test_dp_utils.py` around lines 78 - 81, The test unpacks get_mimo_dp_info into dp_rank and dp_size but never uses them, causing lint warnings; change the unused variables to _dp_rank and _dp_size (or simply _ and _dp_size) in the test(s) that don't assert them (the unpack in tests/unit_tests/data/mimo/test_dp_utils.py around the get_mimo_dp_info calls at the locations mentioned) so the intent is clear and Ruff warnings are silenced while keeping loader_module and needs_data assertions unchanged.

src/megatron/bridge/data/mimo/loaders.py (1)

Line 83: Unused variable: prefix `loader_module` with underscore. The `loader_module` value is unpacked but never used. Per static analysis hint (RUF059), prefix it with an underscore.

♻️ Suggested fix

```diff
- dp_rank, dp_size, needs_data, loader_module = get_mimo_dp_info(cfg.model.mimo_parallelism_config, grids)
+ dp_rank, dp_size, needs_data, _loader_module = get_mimo_dp_info(cfg.model.mimo_parallelism_config, grids)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/megatron/bridge/data/mimo/loaders.py` at line 83, The unpacked variable loader_module from the call to get_mimo_dp_info(cfg.model.mimo_parallelism_config, grids) is unused; rename it to _loader_module (or prefix with an underscore) to satisfy the unused-variable lint rule (RUF059). Update the assignment dp_rank, dp_size, needs_data, loader_module to dp_rank, dp_size, needs_data, _loader_module wherever that pattern appears (e.g., in the function or module that calls get_mimo_dp_info) so the unused value is clearly marked.

src/megatron/bridge/data/mimo/hf_provider.py (1)

Line 8: Prefer modern union syntax over `Union` and `Optional`. Per coding guidelines, use `X | Y` for union types instead of `Union[X, Y]`, and `T | None` instead of `Optional[T]`.

♻️ Suggested fix

```diff
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Tuple
```

Then update line 77:

```diff
- hf_data_files: Optional[Union[str, List[str]]] = None
+ hf_data_files: str | List[str] | None = None
```

As per coding guidelines: "Use `X | Y` for union types instead of `Union[X, Y]`" and "Use `T | None` for nullable types instead of `Optional[T]`".
Verify each finding against the current code and only fix it if needed. In `@src/megatron/bridge/data/mimo/hf_provider.py` at line 8, The file imports Optional and Union from typing; per guidelines remove Optional and Union from the import list and replace their uses with modern PEP 604 syntax (T | None for optional types and A | B for unions). Update the import line in hf_provider.py to drop Optional and Union and then replace every type annotation in this module (including the annotation referenced around line 77) that uses Union[...] or Optional[...] to use the pipe syntax instead so function/method signatures and return types (wherever present in this module) use X | Y and T | None.

tests/unit_tests/training/mimo/test_mimo_step.py (2)

Lines 30-52: Prefix unused `metrics` variable with underscore to satisfy linter. The `metrics` variable is unpacked but not used in these tests. Use a `_` prefix to indicate intentional non-use.

♻️ Suggested fix

```diff
- total_loss, num_tokens, metrics = loss_func(loss_mask, output_tensor)
+ total_loss, num_tokens, _metrics = loss_func(loss_mask, output_tensor)
```

Apply similarly to line 49.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unit_tests/training/mimo/test_mimo_step.py` around lines 30 - 52, In the two tests test_loss_with_all_ones_mask and test_loss_with_all_zeros_mask, the unpacked third return value from loss_func is stored in an unused variable named metrics; rename it to _metrics (or simply _ ) to satisfy the linter and indicate intentional non-use—update the unpacking in both tests where loss_func(loss_mask, output_tensor) is called so it reads e.g. total_loss, num_tokens, _metrics = loss_func(...)
Line 111: Prefix unused `output` variable with underscore. Per static analysis, the `output` variable is unpacked but unused.

♻️ Suggested fix

```diff
- output, loss_fn = forward_step(mock_state, data_iter, mock_model)
+ _output, loss_fn = forward_step(mock_state, data_iter, mock_model)
```

Apply similarly to line 138.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unit_tests/training/mimo/test_mimo_step.py` at line 111, The unpacked variable "output" from the call to forward_step(mock_state, data_iter, mock_model) is unused; rename it to "_output" to satisfy static analysis and avoid unused-variable warnings, and make the same change for the other forward_step unpack at the other location where "output" is unused (apply the same "_output" prefix there); ensure you only rename the local variable, not the call or other identifiers.

tests/unit_tests/training/mimo/test_pretrain_mimo.py (1)

Line 88: Inconsistent module key: `"llm"` vs `"language"`. This test uses `"llm"` as the module key in `module_to_grid_map`, but the rest of this PR consistently uses `"language"` (which matches `MIMO_LANGUAGE_MODULE_KEY`). While the seed setting logic iterates over all keys and doesn't require a specific name, using `"language"` would be more consistent.

♻️ Suggested fix

```diff
- mimo_infra = SimpleNamespace(module_to_grid_map={"llm": grid})
+ mimo_infra = SimpleNamespace(module_to_grid_map={"language": grid})
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unit_tests/training/mimo/test_pretrain_mimo.py` at line 88, Replace the inconsistent module key "llm" in the test's mimo_infra.module_to_grid_map with the canonical key used throughout the PR ("language" / MIMO_LANGUAGE_MODULE_KEY) so the test matches the rest of the codebase; update the SimpleNamespace initialization (mimo_infra = SimpleNamespace(module_to_grid_map={"llm": grid})) to use "language" as the map key to align with the seed-setting logic and naming elsewhere.

src/megatron/bridge/data/mimo/base_provider.py (1)

Lines 6-8: Use Python 3.10+ type hint syntax. Per coding guidelines, use `T | None` instead of `Optional[T]` and built-in `tuple` instead of `typing.Tuple` for Python 3.10+.

♻️ Suggested fix

```diff
-from typing import Callable, Optional, Tuple
+from typing import Callable
```

And update the return type on line 41:

```diff
- ) -> Tuple[Optional[Dataset], Optional[Dataset], Optional[Dataset]]:
+ ) -> tuple[Dataset | None, Dataset | None, Dataset | None]:
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/megatron/bridge/data/mimo/base_provider.py` around lines 6 - 8, Replace typing.Optional and typing.Tuple usages with Python 3.10+ syntax: remove Optional and Tuple from imports and change annotations like Optional[T] to T | None and Tuple[A, B] to tuple[A, B]; update any Callable[...] imports/uses to remain as-is if needed. In particular, update the return type annotation referenced on line 41 to use the new native syntax (e.g., change Optional[X] to X | None or Tuple[...] to tuple[...]) and ensure imports only include abstractmethod and dataclass (and Callable if still used).

src/megatron/bridge/data/mimo/dp_utils.py (1)

Line 6: Use built-in `tuple` for type hints per Python 3.10+ guidelines.

♻️ Suggested fix

```diff
-from typing import TYPE_CHECKING, Dict, Tuple
+from typing import TYPE_CHECKING
```

And update the return type annotation on line 21:

```diff
-) -> Tuple[int, int, bool, str]:
+) -> tuple[int, int, bool, str]:
```

Note: the `Dict` import can also be removed if the type hint on line 20 uses the built-in `dict` instead.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/megatron/bridge/data/mimo/dp_utils.py` at line 6, Replace typing imports Dict and Tuple with built-in types: remove Dict and Tuple from the import line (keep TYPE_CHECKING) and update all type hints that use Dict[...] and Tuple[...] to use built-in dict[...] and tuple[...] instead; specifically change the return type annotation on the function currently annotated with Tuple/Dict to use tuple and dict, and remove the now-unused Dict/Tuple imports in src/megatron/bridge/data/mimo/dp_utils.py.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/megatron/bridge/data/loaders.py`:
- Around line 205-207: The do_valid and do_test flags are using
cfg.train.eval_iters but the module reads evaluation settings from
cfg.validation; update the checks in the block that sets
do_train/do_valid/do_test so that do_valid and do_test use
cfg.validation.eval_iters (or the appropriate cfg.validation field) instead of
cfg.train.eval_iters, keeping do_train unchanged; adjust references to the
variables do_valid and do_test in the surrounding code if needed to ensure
evaluation behavior follows cfg.validation.eval_iters.
In `@src/megatron/bridge/data/mimo/collate.py`:
- Around line 62-63: The docstring for mimo_collate_fn is missing documentation
for the newly collated "loss_mask" field; update the mimo_collate_fn docstring
(Args, Returns and example sections) to list loss_mask as a required tensor in
the collated payload, describe its shape and purpose (e.g., mask for loss
computation), and show it in the example return structure alongside existing
keys so external callers see the API contract.
In `@src/megatron/bridge/models/mimo/mimo_provider.py`:
- Around line 412-413: setup_mimo() already constructs a MimoModelInfra but
build_infra() is being called again here and later in provide(), causing
multiple, inconsistent infra/communicator/DP wrappers; change provide() to
accept a MimoModelInfra (e.g., provide(self, infra: MimoModelInfra, ...)) or add
a private helper (e.g., _provide_with_infra(self, infra, ...)) that constructs
the model from the given infra instead of calling build_infra(); then thread the
existing infra from setup_mimo() into that call and remove the extra
build_infra() invocation so the same infra/communicator/_grids are used across
construction and DDP wrapping.
- Line 201: The call to populate_embedding_and_position_groups(pp_group) must be
executed on all ranks in the same global order because it invokes
torch.distributed.new_group; move the call out of the module participation
conditional so every process calls
populate_embedding_and_position_groups(pp_group) in the same sequence, then only
assign its returned handles (pos_embd_pg, embd_pg) to the module-local variables
for ranks that actually participate in the module; keep the pp_group argument
and preserve use of the returned handles but ensure non-participating ranks
discard or set them to None while still calling the helper to maintain
collective ordering.
In `@src/megatron/bridge/training/mimo_parallel_utils.py`:
- Around line 173-209: The wrapper finalize_model_grads_multimodule currently
discards the schedule-provided force_all_reduce; update the call so the flag is
forwarded into _finalize_model_grads (i.e., call _finalize_model_grads([module],
num_tokens=num_tokens, pg_collection=module_pg,
force_all_reduce=force_all_reduce)) so schedule-controlled behavior is
preserved; locate this change inside finalize_model_grads_multimodule where
module_pg is not None and ensure the symbol force_all_reduce is passed through
unchanged from the wrapper signature that includes infra and
module_to_grid_tuple.
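The flag-forwarding fix above amounts to threading `force_all_reduce` through unchanged. A self-contained sketch, with a stub standing in for Megatron's `_finalize_model_grads` (the real signatures may differ):

```python
calls = []

def _finalize_model_grads(models, num_tokens=None, pg_collection=None, force_all_reduce=False):
    # Stub standing in for the Megatron helper; records what it receives.
    calls.append({"models": models, "force_all_reduce": force_all_reduce})

def finalize_model_grads_multimodule(module_to_pg, num_tokens=None, force_all_reduce=False):
    # Forward the schedule-provided flag instead of silently dropping it.
    for module, module_pg in module_to_pg:
        if module_pg is not None:
            _finalize_model_grads(
                [module],
                num_tokens=num_tokens,
                pg_collection=module_pg,
                force_all_reduce=force_all_reduce,
            )

finalize_model_grads_multimodule([("language", "pg0"), ("vision", None)], force_all_reduce=True)
print([c["force_all_reduce"] for c in calls])  # [True]
```

Modules whose process-group collection is `None` are skipped, and the surviving call sees exactly the flag the schedule passed in.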
In `@src/megatron/bridge/training/pretrain_mimo.py`:
- Around line 58-65: The seed assignment treats 0 as unset because it relies on
truthiness; change the logic in the seed initialization so you check for None
explicitly: prefer cfg.seed when cfg has attribute 'seed' even if it's 0,
otherwise fall back to cfg.rng.seed if that attribute exists and is not None,
and finally default to 1234; update the expression that defines the variable
seed (the current getattr chain) to use explicit "is not None" checks or a small
helper that returns the first non-None of cfg.seed, getattr(cfg, 'rng',
None).seed, and 1234 so reproducible seeds like 0 are preserved.
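The explicit `is not None` rule described above can be captured in a small helper. This is a sketch; the attribute names follow the comment, not necessarily the final code:

```python
from types import SimpleNamespace

def resolve_seed(cfg, default=1234):
    # First non-None of cfg.seed and cfg.rng.seed, else the default.
    # Truthiness is deliberately avoided so seed=0 is preserved.
    for candidate in (
        getattr(cfg, "seed", None),
        getattr(getattr(cfg, "rng", None), "seed", None),
    ):
        if candidate is not None:
            return candidate
    return default

print(resolve_seed(SimpleNamespace(seed=0)))                       # 0, not 1234
print(resolve_seed(SimpleNamespace(rng=SimpleNamespace(seed=7))))  # 7
print(resolve_seed(SimpleNamespace()))                             # 1234
```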
In `@src/megatron/bridge/training/train_mimo.py`:
- Around line 164-185: The broadcast currently only updates loss_dict on the
logging rank when source_rank != last_rank and skips broadcast when source_rank
== last_rank, leaving other ranks with empty dicts; change the logic to always
call torch.distributed.broadcast_object_list whenever source_rank >= 0 (use
src=source_rank) so every rank receives the same object, then on every rank set
loss_dict = received_obj or {} and move any tensor values to the local CUDA
device (reuse the existing received->loss_dict assignment and tensor .cuda()
conversion logic); update references: last_rank, my_rank, source_rank,
loss_dict, and broadcast_object_list in train_mimo.py so training_log sees
identical loss keys on all ranks.
- Around line 397-402: The condition uses train_config.eval_interval but the
project standard is cfg.validation.*; update the evaluation gating in the MIMO
loop to read eval_interval from the validation config (e.g., use validation_cfg
= cfg.validation or cfg.validation.eval_interval) instead of
train_config.eval_interval, and change the check to use
validation_cfg.eval_interval (ensure the variable is in scope) in the clause
that currently references train_state.step % train_config.eval_interval and the
None check, leaving train_state.step and valid_data_iterator as-is.
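The always-broadcast pattern described for `train_mimo.py` looks roughly like this single-process demo (world_size=1, gloo backend; `source_rank` and the loss key are illustrative values, not the real training-loop state):

```python
import os
import torch
import torch.distributed as dist

# Single-process sketch: every rank calls broadcast_object_list whenever
# source_rank >= 0, so all ranks end up with identical loss keys.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29511")
dist.init_process_group("gloo", rank=0, world_size=1)

source_rank = 0  # rank that actually computed the loss (demo value)
obj = [{"lm loss": torch.tensor(1.5)}] if dist.get_rank() == source_rank else [None]
if source_rank >= 0:
    dist.broadcast_object_list(obj, src=source_rank)
loss_dict = obj[0] or {}  # identical on every rank after the broadcast
# (The real loop would move tensor values to the local CUDA device here.)
print(sorted(loss_dict))
dist.destroy_process_group()
```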
---
Outside diff comments:
In `@tests/unit_tests/models/mimo/test_mimo_builder.py`:
- Around line 122-125: Update the stale inline comment that reads "# First call
(llm): shape [8, 1, 1, 2, 1]" to reference the new module key name "language"
instead of "llm"; locate the comment near the assertions using first_call_kwargs
and mock_grid_class.call_args_list in test_mimo_builder.py and change it to
something like "# First call (language): shape [8, 1, 1, 2, 1]" so the comment
matches the updated module key.
---
Nitpick comments:
In `@src/megatron/bridge/data/mimo/base_provider.py`:
- Around line 6-8: Replace typing.Optional and typing.Tuple usages with Python
3.10+ syntax: remove Optional and Tuple from imports and change annotations like
Optional[T] to T | None and Tuple[A, B] to tuple[A, B]; update any Callable[...]
imports/uses to remain as-is if needed. In particular, update the return type
annotation referenced on line 41 to use the new native syntax (e.g., change
Optional[X] to X | None or Tuple[...] to tuple[...]) and ensure imports only
include abstractmethod and dataclass (and Callable if still used).
In `@src/megatron/bridge/data/mimo/dp_utils.py`:
- Line 6: Replace typing imports Dict and Tuple with built-in types: remove Dict
and Tuple from the import line (keep TYPE_CHECKING) and update all type hints
that use Dict[...] and Tuple[...] to use built-in dict[...] and tuple[...]
instead; specifically change the return type annotation on the function
currently annotated with Tuple/Dict to use tuple and dict, and remove the
now-unused Dict/Tuple imports in src/megatron/bridge/data/mimo/dp_utils.py.
In `@src/megatron/bridge/data/mimo/hf_provider.py`:
- Line 8: The file imports Optional and Union from typing; per guidelines remove
Optional and Union from the import list and replace their uses with modern PEP
604 syntax (T | None for optional types and A | B for unions). Update the import
line in hf_provider.py to drop Optional and Union and then replace every type
annotation in this module (including the annotation referenced around line 77)
that uses Union[...] or Optional[...] to use the pipe syntax instead so
function/method signatures and return types (wherever present in this module)
use X | Y and T | None.
In `@src/megatron/bridge/data/mimo/loaders.py`:
- Line 83: The unpacked variable loader_module from the call to
get_mimo_dp_info(cfg.model.mimo_parallelism_config, grids) is unused; rename it
to _loader_module (or prefix with an underscore) to satisfy the unused-variable
lint rule (RUF059). Update the assignment dp_rank, dp_size, needs_data,
loader_module to dp_rank, dp_size, needs_data, _loader_module wherever that
pattern appears (e.g., in the function or module that calls get_mimo_dp_info) so
the unused value is clearly marked.
In `@tests/unit_tests/data/mimo/test_dp_utils.py`:
- Around line 78-81: The test unpacks get_mimo_dp_info into dp_rank and dp_size
but never uses them, causing lint warnings; change the unused variables to
_dp_rank and _dp_size (or simply _ and _dp_size) in the test(s) that don't
assert them (the unpack in tests/unit_tests/data/mimo/test_dp_utils.py around
the get_mimo_dp_info calls at the locations mentioned) so the intent is clear
and Ruff warnings are silenced while keeping loader_module and needs_data
assertions unchanged.
In `@tests/unit_tests/training/mimo/test_mimo_step.py`:
- Around line 30-52: In the two tests test_loss_with_all_ones_mask and
test_loss_with_all_zeros_mask, the unpacked third return value from loss_func is
stored in an unused variable named metrics; rename it to _metrics (or simply _ )
to satisfy the linter and indicate intentional non-use—update the unpacking in
both tests where loss_func(loss_mask, output_tensor) is called so it reads e.g.
total_loss, num_tokens, _metrics = loss_func(...)
- Line 111: The unpacked variable "output" from the call to
forward_step(mock_state, data_iter, mock_model) is unused; rename it to
"_output" to satisfy static analysis and avoid unused-variable warnings, and
make the same change for the other forward_step unpack at the other location
where "output" is unused (apply the same "_output" prefix there); ensure you
only rename the local variable, not the call or other identifiers.
In `@tests/unit_tests/training/mimo/test_pretrain_mimo.py`:
- Line 88: Replace the inconsistent module key "llm" in the test's
mimo_infra.module_to_grid_map with the canonical key used throughout the PR
("language" / MIMO_LANGUAGE_MODULE_KEY) so the test matches the rest of the
codebase; update the SimpleNamespace initialization (mimo_infra =
SimpleNamespace(module_to_grid_map={"llm": grid})) to use "language" as the map
key to align with the seed-setting logic and naming elsewhere.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: c2643fa1-0ac5-41ab-a44a-06d77567647a
📒 Files selected for processing (34)
- 3rdparty/Megatron-LM
- src/megatron/bridge/data/loaders.py
- src/megatron/bridge/data/mimo/__init__.py
- src/megatron/bridge/data/mimo/base_provider.py
- src/megatron/bridge/data/mimo/collate.py
- src/megatron/bridge/data/mimo/dataset.py
- src/megatron/bridge/data/mimo/dp_utils.py
- src/megatron/bridge/data/mimo/hf_provider.py
- src/megatron/bridge/data/mimo/loaders.py
- src/megatron/bridge/data/mimo/mock_provider.py
- src/megatron/bridge/models/mimo/llava_provider.py
- src/megatron/bridge/models/mimo/mimo_builder.py
- src/megatron/bridge/models/mimo/mimo_config.py
- src/megatron/bridge/models/mimo/mimo_ddp.py
- src/megatron/bridge/models/mimo/mimo_provider.py
- src/megatron/bridge/training/mimo_parallel_utils.py
- src/megatron/bridge/training/mimo_step.py
- src/megatron/bridge/training/pretrain_mimo.py
- src/megatron/bridge/training/train_mimo.py
- src/megatron/bridge/training/utils/train_utils.py
- tests/unit_tests/data/mimo/test_collate.py
- tests/unit_tests/data/mimo/test_dataset.py
- tests/unit_tests/data/mimo/test_dp_utils.py
- tests/unit_tests/data/mimo/test_hf_provider.py
- tests/unit_tests/data/mimo/test_loaders.py
- tests/unit_tests/models/mimo/test_llava_provider.py
- tests/unit_tests/models/mimo/test_mimo_builder.py
- tests/unit_tests/models/mimo/test_mimo_ddp.py
- tests/unit_tests/models/mimo/test_mimo_provider.py
- tests/unit_tests/training/mimo/__init__.py
- tests/unit_tests/training/mimo/test_mimo_config.py
- tests/unit_tests/training/mimo/test_mimo_parallel_utils.py
- tests/unit_tests/training/mimo/test_mimo_step.py
- tests/unit_tests/training/mimo/test_pretrain_mimo.py
```python
do_train = train_dataloader is not None and cfg.train.train_iters > 0
do_valid = valid_dataloader is not None and cfg.train.eval_iters > 0
do_test = test_dataloader is not None and cfg.train.eval_iters > 0
```
Use the validation config here.
The rest of this module reads evaluation settings from cfg.validation; using cfg.train.eval_iters on the MIMO path can raise or permanently disable do_valid/do_test.
Suggested fix
```diff
- do_valid = valid_dataloader is not None and cfg.train.eval_iters > 0
- do_test = test_dataloader is not None and cfg.train.eval_iters > 0
+ do_valid = valid_dataloader is not None and cfg.validation.eval_iters > 0
+ do_test = test_dataloader is not None and cfg.validation.eval_iters > 0
```

🤖 Prompt for AI Agents
+ do_test = test_dataloader is not None and cfg.validation.eval_iters > 0🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/megatron/bridge/data/loaders.py` around lines 205 - 207, The do_valid and
do_test flags are using cfg.train.eval_iters but the module reads evaluation
settings from cfg.validation; update the checks in the block that sets
do_train/do_valid/do_test so that do_valid and do_test use
cfg.validation.eval_iters (or the appropriate cfg.validation field) instead of
cfg.train.eval_iters, keeping do_train unchanged; adjust references to the
variables do_valid and do_test in the surrounding code if needed to ensure
evaluation behavior follows cfg.validation.eval_iters.
```python
loss_mask = torch.stack([item["loss_mask"] for item in batch])
```
Update mimo_collate_fn docstring to include loss_mask.
Line 62 and Line 114 make loss_mask part of the required collated payload, but the Args/Returns/example contract still omits it.
📝 Suggested doc update
```diff
 Args:
     batch: List of examples from MimoDataset, each containing:
         - input_ids: Token IDs with placeholder tokens
         - labels: Labels for causal LM training
+        - loss_mask: Per-token loss mask
         - attention_mask: Attention mask
         - position_ids: Position indices
         - modality_inputs: Dict[str, Dict[str, Any]] with preprocessed inputs
@@
 Returns:
     Dict containing:
         - input_ids: (batch, seq) stacked token IDs
         - labels: (batch, seq) stacked labels
+        - loss_mask: (batch, seq) stacked per-token loss mask
         - attention_mask: (batch, seq) attention mask
         - position_ids: (batch, seq) position indices
         - modality_inputs: Dict[str, Dict[str, Tensor]] with batched modality tensors
```

As per coding guidelines: for interfaces that may be used outside a file, prefer docstrings over comments.
Also applies to: 114-114
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/megatron/bridge/data/mimo/collate.py` around lines 62 - 63, The docstring
for mimo_collate_fn is missing documentation for the newly collated "loss_mask"
field; update the mimo_collate_fn docstring (Args, Returns and example sections)
to list loss_mask as a required tensor in the collated payload, describe its
shape and purpose (e.g., mask for loss computation), and show it in the example
return structure alongside existing keys so external callers see the API
contract.
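A minimal sketch of the contract the docstring should state (not the real `mimo_collate_fn`; the modality-input handling is omitted and the keys are taken from the suggested doc update above):

```python
import torch

def collate(batch):
    # Stack the per-example tensors, loss_mask included, into (batch, seq) tensors.
    keys = ("input_ids", "labels", "loss_mask", "attention_mask", "position_ids")
    return {k: torch.stack([item[k] for item in batch]) for k in keys}

example = {
    "input_ids": torch.arange(4),
    "labels": torch.arange(4),
    "loss_mask": torch.tensor([1, 1, 0, 0]),
    "attention_mask": torch.ones(4, dtype=torch.long),
    "position_ids": torch.arange(4),
}
out = collate([example, example])
print(tuple(out["loss_mask"].shape))  # (2, 4)
```

Documenting `loss_mask` alongside the other stacked keys makes the downstream loss computation's dependency on the collated payload explicit.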
```diff
  # Create embedding groups for PP > 1 (collective operation on all PP ranks)
- pos_embd_pg, embd_pg = create_embedding_and_position_groups(pp_group)
+ pos_embd_pg, embd_pg = populate_embedding_and_position_groups(pp_group)
```
🧩 Analysis chain
🌐 Web query:
Does torch.distributed.new_group need to be called by every rank in the default process group, including non-members, and in the same global order?
💡 Result:
Yes, torch.distributed.new_group must be called by every rank in the default process group (the world group), including non-members of the new subgroup. This is required for synchronization via an internal store-based barrier (_store_based_barrier) that ensures consistent global metadata across all ranks about the created process groups. Non-members receive a special NON_GROUP_MEMBER handle (id -100) but still participate in the barrier. It must also be called in the same global order across all ranks to avoid deadlocks or hangs, especially when creating multiple overlapping groups. Divergent order among participants of individual groups can still cause issues in some backends like Gloo, even if participant-local order is consistent. With use_local_synchronization=True, non-participants can skip, but global order consistency is still advised to prevent backend-specific hangs.
Citations:
- 1: https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
- 2: new_group() returns same group for unique rank inputs pytorch/pytorch#75755
- 3: https://runebook.dev/en/articles/pytorch/distributed/torch.distributed.new_group
- 4: dist.new_group(..., use_local_synchronization=True) hangs with Gloo for overlapping groups, while NCCL succeeds on the same participant-consistent creation order pytorch/pytorch#177959
- 5: torch.distributed.new_group got stuck pytorch/pytorch#134314
- 6: Lack of clear story around sub process group handling in torch.distributed pytorch/pytorch#81291
🏁 Script executed:

```shell
# First, let's examine the relevant section of mimo_provider.py around line 201
head -n 230 src/megatron/bridge/models/mimo/mimo_provider.py | tail -n 50
```

Length of output: 2127

🏁 Script executed:

```shell
# Find the populate_embedding_and_position_groups function definition
rg "def populate_embedding_and_position_groups" -A 15
```

Length of output: 1460

🏁 Script executed:

```shell
# Check the full context of _get_pg_collections_from_grids to understand the guard
rg "_get_pg_collections_from_grids" -A 40
```

Length of output: 15136
Call populate_embedding_and_position_groups() on all ranks in the same global order.
populate_embedding_and_position_groups() invokes torch.distributed.new_group(), which requires participation from all ranks in the default process group in the same global order, regardless of membership in the created subgroup. The current code guards this call inside the module participation check (line 201), skipping it on non-participating ranks. With multiple modules and disjoint rank sets, this breaks collective ordering as soon as any module has PP > 1.
Move the helper call outside the participation guard and only assign the returned process group handles to ranks that actually belong to the module.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/megatron/bridge/models/mimo/mimo_provider.py` at line 201, The call to
populate_embedding_and_position_groups(pp_group) must be executed on all ranks
in the same global order because it invokes torch.distributed.new_group; move
the call out of the module participation conditional so every process calls
populate_embedding_and_position_groups(pp_group) in the same sequence, then only
assign its returned handles (pos_embd_pg, embd_pg) to the module-local variables
for ranks that actually participate in the module; keep the pp_group argument
and preserve use of the returned handles but ensure non-participating ranks
discard or set them to None while still calling the helper to maintain
collective ordering.
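The suggested restructuring can be simulated without real process groups. In the sketch below, all names are stand-ins; `fake_new_group` mimics only the call-ordering aspect of `torch.distributed.new_group`: every simulated rank calls the group-creating helper unconditionally, and only member ranks keep the returned handles.

```python
calls_per_rank = {}


def fake_new_group(rank, ranks):
    """Stand-in for torch.distributed.new_group: records call order per rank."""
    calls_per_rank.setdefault(rank, []).append(tuple(ranks))
    # Non-members get a placeholder, mirroring the NON_GROUP_MEMBER handle.
    return ("pg", tuple(ranks)) if rank in ranks else None


module_ranks = {"llm": [0, 1, 2, 3], "vision": [4, 5]}
handles = {}

for rank in range(6):
    handles[rank] = {}
    for name, ranks in sorted(module_ranks.items()):
        pg = fake_new_group(rank, ranks)   # called on EVERY rank, same order...
        if rank in ranks:                  # ...but assigned only on member ranks
            handles[rank][name] = pg

# All ranks issued the identical collective sequence.
assert len({tuple(seq) for seq in calls_per_rank.values()}) == 1
# Non-members of "vision" hold no vision handle; members do.
assert "vision" not in handles[0] and handles[4]["vision"] == ("pg", (4, 5))
```

The participation check moves from around the collective call to around the handle assignment, which is exactly the shape the review asks for.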
```python
# Build infrastructure
infra = self.build_infra()
```
Avoid rebuilding MIMO infra during distributed model construction.
setup_mimo() already builds one MimoModelInfra, this method builds another, and self.provide() builds a third. That can bind the communicator, model internals, DDP wrappers, and _grids consumers to different process-group objects for the same logical module topology.
Please thread the already-built infra into model construction, or split provide() into a private helper that accepts an MimoModelInfra instead of recreating it.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/megatron/bridge/models/mimo/mimo_provider.py` around lines 412 - 413,
setup_mimo() already constructs a MimoModelInfra but build_infra() is being
called again here and later in provide(), causing multiple, inconsistent
infra/communicator/DP wrappers; change provide() to accept a MimoModelInfra
(e.g., provide(self, infra: MimoModelInfra, ...)) or add a private helper (e.g.,
_provide_with_infra(self, infra, ...)) that constructs the model from the given
infra instead of calling build_infra(); then thread the existing infra from
setup_mimo() into that call and remove the extra build_infra() invocation so the
same infra/communicator/_grids are used across construction and DDP wrapping.
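One possible shape for the suggested refactor, sketched with stand-in classes (the real `MimoModelInfra` and provider are far more involved; only the names quoted in the review are taken from the PR): the infra is built once, cached, and every construction path delegates to a private helper that consumes the given instance.

```python
class MimoProviderSketch:
    """Build infra once; all construction paths consume the same instance."""

    def __init__(self):
        self.build_count = 0
        self._infra = None

    def build_infra(self):
        self.build_count += 1
        return {"id": self.build_count}          # stand-in for MimoModelInfra

    def setup_mimo(self):
        self._infra = self.build_infra()         # built exactly once
        return self._provide_with_infra(self._infra)

    def provide(self):
        # Public path reuses the cached infra instead of rebuilding it.
        infra = self._infra if self._infra is not None else self.build_infra()
        return self._provide_with_infra(infra)

    def _provide_with_infra(self, infra):
        # Model construction binds to the *given* infra's process groups.
        return {"model": "mimo", "infra_id": infra["id"]}


p = MimoProviderSketch()
m1 = p.setup_mimo()
m2 = p.provide()
assert p.build_count == 1 and m1 == m2   # one infra, consistent bindings
```

This keeps the communicator, model internals, and DDP wrappers bound to a single set of process-group objects.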
```python
def finalize_model_grads_multimodule(
    model,
    num_tokens=None,
    pg_collection=None,
    force_all_reduce=None,
    *,
    infra: MimoModelInfra,
    module_to_grid_tuple: List[Tuple],
):
    """Finalize gradients for each module using infra.pg_collections.

    IMPORTANT: Signature matches schedule's call pattern:
        config.finalize_model_grads_func([model], num_tokens, pg_collection, force_all_reduce=flag)

    The `infra` and `module_to_grid_tuple` parameters are pre-bound via partial().
    We ignore the schedule-provided `pg_collection` and use per-module PGs.

    Args:
        model: Model list (passed by schedule, ignored - we use module_to_grid_tuple).
        num_tokens: Token count for gradient scaling.
        pg_collection: Schedule-provided PG (ignored - we use per-module PGs).
        force_all_reduce: Schedule-provided flag (ignored - per-module PGs control sync).
        infra: MimoModelInfra with per-module pg_collections (keyword-only, bound via partial).
        module_to_grid_tuple: List of (module, grid) tuples (keyword-only, bound via partial).
    """
    for module, grid in module_to_grid_tuple:
        if module is not None and is_current_rank_in_grid(grid):
            # Get the module's pg_collection from infra
            # Find the module name by matching the grid
            module_pg = None
            for module_name, mod_grid in infra.module_to_grid_map.items():
                if mod_grid is grid:
                    module_pg = infra.pg_collections.get(module_name)
                    break

            if module_pg is not None:
                _finalize_model_grads([module], num_tokens=num_tokens, pg_collection=module_pg)
```
🧩 Analysis chain
🌐 Web query:
What is the current signature of megatron.core.distributed.finalize_model_grads.finalize_model_grads, and does it accept a force_all_reduce argument?
💡 Result:
The current signature of megatron.core.distributed.finalize_model_grads.finalize_model_grads is: finalize_model_grads(model: List[torch.nn.Module], num_tokens: Optional[torch.Tensor] = None, pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None, force_all_reduce: Optional[bool] = False) Yes, it accepts a force_all_reduce argument, which is an optional boolean parameter with default value False. This function all-reduces model grads across DP replicas, layernorm grads for sequence parallelism, embedding grads across pipeline stages (if not tied), and scales gradients by num_tokens.
🏁 Script executed:

```shell
# Find where _finalize_model_grads is defined and how it's used
rg "_finalize_model_grads" --type py -B 2 -A 5
```

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1508
🏁 Script executed:

```shell
# Check the specific file to see if there are other calls or context
cat -n src/megatron/bridge/training/mimo_parallel_utils.py | head -220
```

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 9984
Forward force_all_reduce to _finalize_model_grads.
The wrapper accepts force_all_reduce from the schedule but silently discards it at line 209. Although the docstring claims this is intentional ("ignored - per-module PGs control sync"), the underlying _finalize_model_grads function accepts and processes this parameter. Dropping a schedule-controlled flag changes gradient-finalization behavior and may lead to unexpected behavior in MIMO paths. Forward the parameter to maintain consistency with the schedule's expectations.
Suggested fix

```diff
-                _finalize_model_grads([module], num_tokens=num_tokens, pg_collection=module_pg)
+                _finalize_model_grads(
+                    [module],
+                    num_tokens=num_tokens,
+                    pg_collection=module_pg,
+                    force_all_reduce=force_all_reduce,
+                )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/megatron/bridge/training/mimo_parallel_utils.py` around lines 173 - 209,
The wrapper finalize_model_grads_multimodule currently discards the
schedule-provided force_all_reduce; update the call so the flag is forwarded
into _finalize_model_grads (i.e., call _finalize_model_grads([module],
num_tokens=num_tokens, pg_collection=module_pg,
force_all_reduce=force_all_reduce)) so schedule-controlled behavior is
preserved; locate this change inside finalize_model_grads_multimodule where
module_pg is not None and ensure the symbol force_all_reduce is passed through
unchanged from the wrapper signature that includes infra and
module_to_grid_tuple.
```python
import random

import numpy as np
import torch
from megatron.core import tensor_parallel

seed = getattr(cfg, "seed", None) or getattr(getattr(cfg, "rng", None), "seed", None) or 1234
```
Don't treat seed 0 as unset.
This truthy chain silently replaces cfg.seed = 0 with cfg.rng.seed or 1234, which breaks reproducibility for a valid explicit seed.
Suggested fix

```diff
-    seed = getattr(cfg, "seed", None) or getattr(getattr(cfg, "rng", None), "seed", None) or 1234
+    seed = getattr(cfg, "seed", None)
+    if seed is None:
+        seed = getattr(getattr(cfg, "rng", None), "seed", None)
+    if seed is None:
+        seed = 1234
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
import random

import numpy as np
import torch
from megatron.core import tensor_parallel

seed = getattr(cfg, "seed", None)
if seed is None:
    seed = getattr(getattr(cfg, "rng", None), "seed", None)
if seed is None:
    seed = 1234
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/megatron/bridge/training/pretrain_mimo.py` around lines 58 - 65, The seed
assignment treats 0 as unset because it relies on truthiness; change the logic
in the seed initialization so you check for None explicitly: prefer cfg.seed
when cfg has attribute 'seed' even if it's 0, otherwise fall back to
cfg.rng.seed if that attribute exists and is not None, and finally default to
1234; update the expression that defines the variable seed (the current getattr
chain) to use explicit "is not None" checks or a small helper that returns the
first non-None of cfg.seed, getattr(cfg, 'rng', None).seed, and 1234 so
reproducible seeds like 0 are preserved.
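A small first-non-None helper (hypothetical, not in the codebase) makes the corrected fallback reusable and shows why the `or` chain mishandles 0:

```python
def first_not_none(*values, default=None):
    """Return the first argument that is not None (0 and "" count as set)."""
    for v in values:
        if v is not None:
            return v
    return default


class Rng:
    seed = 42


class Cfg:
    seed = 0      # a valid, explicit seed
    rng = Rng()


cfg = Cfg()
seed = first_not_none(
    getattr(cfg, "seed", None),
    getattr(getattr(cfg, "rng", None), "seed", None),
    default=1234,
)
assert seed == 0                    # explicit 0 survives
assert (0 or 42 or 1234) == 42      # the original `or` chain would pick 42
```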
```python
# Broadcast loss_dict to all ranks (the last rank is the logging rank for
# W&B/TensorBoard). Use broadcast_object_list from the source rank so every
# rank ends up with the same dict — no fragile P2P or GPU-side pickle needed.
last_rank = dist.get_world_size() - 1
my_rank = dist.get_rank()

# All ranks agree on which rank holds the loss (pick highest rank with data).
has_loss = 1 if loss_dict else 0
source_tensor = torch.tensor([my_rank if has_loss else -1], dtype=torch.int32, device="cuda")
torch.distributed.all_reduce(source_tensor, op=torch.distributed.ReduceOp.MAX)
source_rank = int(source_tensor.item())

# Only broadcast if the source and logging rank differ and a valid source exists.
if source_rank >= 0 and source_rank != last_rank:
    obj = [loss_dict if my_rank == source_rank else None]
    torch.distributed.broadcast_object_list(obj, src=source_rank)
    if my_rank == last_rank:
        received = obj[0] or {}
        # Tensors inside the received dict carry the source rank's CUDA device;
        # move them to this rank's device so training_log arithmetic works.
        loss_dict = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in received.items()}
```
Broadcast the reduced loss dict to every rank.
After broadcast_object_list(), only last_rank rewrites loss_dict; every other non-source rank still keeps {}. If source_rank == last_rank, the broadcast is skipped entirely. That leaves training_log() seeing different loss keys across ranks.
Suggested fix

```diff
-    # Only broadcast if the source and logging rank differ and a valid source exists.
-    if source_rank >= 0 and source_rank != last_rank:
+    if source_rank >= 0:
         obj = [loss_dict if my_rank == source_rank else None]
         torch.distributed.broadcast_object_list(obj, src=source_rank)
-        if my_rank == last_rank:
-            received = obj[0] or {}
-            # Tensors inside the received dict carry the source rank's CUDA device;
-            # move them to this rank's device so training_log arithmetic works.
-            loss_dict = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in received.items()}
+        received = obj[0] or {}
+        loss_dict = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in received.items()}
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/megatron/bridge/training/train_mimo.py` around lines 164 - 185, The
broadcast currently only updates loss_dict on the logging rank when source_rank
!= last_rank and skips broadcast when source_rank == last_rank, leaving other
ranks with empty dicts; change the logic to always call
torch.distributed.broadcast_object_list whenever source_rank >= 0 (use
src=source_rank) so every rank receives the same object, then on every rank set
loss_dict = received_obj or {} and move any tensor values to the local CUDA
device (reuse the existing received->loss_dict assignment and tensor .cuda()
conversion logic); update references: last_rank, my_rank, source_rank,
loss_dict, and broadcast_object_list in train_mimo.py so training_log sees
identical loss keys on all ranks.
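The corrected flow can be simulated in-process with plain lists standing in for `all_reduce` and `broadcast_object_list` (no process groups involved; ranks are just list indices here):

```python
world_size = 4
per_rank_loss = [{}, {"lm loss": 1.25}, {}, {}]   # only rank 1 computed a loss

# all_reduce(MAX) stand-in: the highest rank holding data becomes the source.
source_rank = max(r if per_rank_loss[r] else -1 for r in range(world_size))

# broadcast_object_list stand-in: every rank receives the source's dict,
# even when the source is already the logging (last) rank.
if source_rank >= 0:
    received = per_rank_loss[source_rank]
    per_rank_loss = [dict(received) for _ in range(world_size)]

# After the broadcast, all ranks (not just the logging rank) agree on keys.
assert source_rank == 1
assert all(d == {"lm loss": 1.25} for d in per_rank_loss)
```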
```python
# Evaluation at specified intervals
if (
    train_config.eval_interval is not None
    and train_state.step % train_config.eval_interval == 0
    and valid_data_iterator is not None
):
```
Read eval_interval from cfg.validation.
The rest of the setup/data path uses cfg.validation.*. Looking under cfg.train here means the MIMO loop can skip evaluation or fail on configs that don't mirror the field into train.
Suggested fix

```diff
-            if (
-                train_config.eval_interval is not None
-                and train_state.step % train_config.eval_interval == 0
-                and valid_data_iterator is not None
-            ):
+            if (
+                cfg.validation.eval_interval is not None
+                and train_state.step % cfg.validation.eval_interval == 0
+                and valid_data_iterator is not None
+            ):
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/megatron/bridge/training/train_mimo.py` around lines 397 - 402, The
condition uses train_config.eval_interval but the project standard is
cfg.validation.*; update the evaluation gating in the MIMO loop to read
eval_interval from the validation config (e.g., use validation_cfg =
cfg.validation or cfg.validation.eval_interval) instead of
train_config.eval_interval, and change the check to use
validation_cfg.eval_interval (ensure the variable is in scope) in the clause
that currently references train_state.step % train_config.eval_interval and the
None check, leaving train_state.step and valid_data_iterator as-is.
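A minimal sketch of the suggested gating with stand-in config dataclasses (field names follow the review; the classes themselves are hypothetical, not the real config schema):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ValidationConfig:
    eval_interval: Optional[int] = None


@dataclass
class Cfg:
    validation: ValidationConfig = None


def should_evaluate(cfg, step, valid_data_iterator):
    # Read the interval from the validation section, not the train section.
    interval = cfg.validation.eval_interval
    return (interval is not None
            and step % interval == 0
            and valid_data_iterator is not None)


cfg = Cfg(validation=ValidationConfig(eval_interval=100))
assert should_evaluate(cfg, 200, valid_data_iterator=object())
assert not should_evaluate(cfg, 150, valid_data_iterator=object())
assert not should_evaluate(Cfg(validation=ValidationConfig()), 200, object())
```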
Summary
Adds MiMo (Multi-Input Multi-Output) training support to Megatron-Bridge, enabling heterogeneous multi-modal model training with independent per-module parallelism.
- `pretrain_mimo`/`train_mimo`/`mimo_step` entry points for MiMo-aware training with per-module forward/backward orchestration
- Heterogeneous parallelism plumbing (`mimo_parallel_utils`)

Phase 5 (checkpointing, evaluation, e2e tests) is stacked in a follow-up PR.
Validation
Stack
Summary by CodeRabbit
New Features
Refactor