
feat(mimo): Phase 4 - MiMo training, model/provider, data loading, heterogeneous parallelism#2869

Open
aroshanghias-nvd wants to merge 6 commits into main from mimo/phase4-training-rebuild

Conversation

@aroshanghias-nvd
Contributor

@aroshanghias-nvd aroshanghias-nvd commented Mar 18, 2026

Summary

Adds MiMo (Multi-Input Multi-Output) training support to Megatron-Bridge, enabling heterogeneous multi-modal model training with independent per-module parallelism.

  • MimoModelProvider: ModuleSpec-based model construction with heterogeneous LLaVA support (language model + vision encoder with independent configs)
  • Training loop: pretrain_mimo / train_mimo / mimo_step entry points for MiMo-aware training with per-module forward/backward orchestration
  • Heterogeneous parallelism: Each module (LLM, vision encoder, etc.) can run with its own TP/PP/DP configuration on a disjoint set of ranks (mimo_parallel_utils)
  • Data loading: MiMo-aware collation, dataset, and data loader dispatch routing for multi-modal inputs
  • DDP wrapping: Per-module distributed data parallel with rank-aware grid assignment
  • Megatron-LM submodule pinned to PR #3212 head
  • Full unit test coverage (122 tests)

Phase 5 (checkpointing, evaluation, e2e tests) is stacked in a follow-up PR.
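
As a rough illustration of what "independent per-module parallelism on disjoint ranks" means, here is a minimal sketch; the dict layout, field names, and rank numbers are hypothetical for illustration, not the actual MimoParallelismConfig API.

```python
# Hypothetical per-module heterogeneous parallelism layout: each module owns
# its own (tp, pp, dp) grid on a disjoint range of ranks. Names and numbers
# are illustrative, not the real Megatron-Bridge config schema.
module_parallelism = {
    "language": {"tp": 4, "pp": 2, "dp": 2, "rank_offset": 0},   # ranks 0-15
    "vision":   {"tp": 1, "pp": 1, "dp": 4, "rank_offset": 16},  # ranks 16-19
}

def grid_size(cfg: dict) -> int:
    """World-size footprint of one module's tp x pp x dp grid."""
    return cfg["tp"] * cfg["pp"] * cfg["dp"]

def total_world_size(modules: dict) -> int:
    """Because the rank sets are disjoint, module footprints sum to the
    global world size."""
    return sum(grid_size(c) for c in modules.values())
```

Under this layout the language model would occupy a 4x2x2 grid (16 ranks) and the vision encoder a 1x1x4 grid (4 ranks), for a 20-rank world.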

Validation

  • 122 unit tests passed (models/mimo, training/mimo, data/mimo)

Stack

  • PR1 (this): Phase 4 — training, model, data, parallelism
  • PR2: Phase 5 — checkpoint save/resume, evaluation, e2e tests

Summary by CodeRabbit

  • New Features

    • Added MIMO (Multi-Input Multi-Output) training framework enabling heterogeneous multi-module parallelism with dedicated pretraining entry point and training pipeline
    • Enhanced data loading infrastructure for multi-module models with loss masking support
  • Refactor

    • Reorganized model infrastructure, process group management, and distributed utilities for improved multi-module training efficiency and flexibility

@copy-pr-bot

copy-pr-bot bot commented Mar 18, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@aroshanghias-nvd changed the title from "feat(mimo): MiMo training, model/provider, data loading, heterogeneous parallelism" to "feat(mimo): Phase 4 - MiMo training, model/provider, data loading, heterogeneous parallelism" on Mar 18, 2026
…rallelism

Squash of all Phase 4 MiMo work from mimo/phase4-training (47870e4),
rebased onto upstream/main at f1fb06a.

Includes:
- MimoModelProvider with ModuleSpec-based API and heterogeneous LLaVA support
- MiMo training loop (pretrain_mimo, train_mimo, mimo_step)
- Heterogeneous TP/PP/DP parallelism plumbing (mimo_parallel_utils)
- MiMo data loading (collate, dataset, loaders, hf_provider, mock_provider)
- Data loader dispatch routing for MIMO models (loaders.py)
- MiMo DDP wrapping and model builder
- Kamran's loss mask and heterogeneous LLaVA dataset support
- Megatron-LM submodule pinned to PR #3212 head
- Full unit test coverage (provider, config, step, collate, pretrain tests)

Phase 5 (checkpointing/evaluation) is stacked in a separate branch.

Original commit history preserved in backup/mimo-phase4-training-v0 (47870e4).

Signed-off-by: Ali Roshan Ghias <aroshanghias@nvidia.com>
@aroshanghias-nvd force-pushed the mimo/phase4-training-rebuild branch from ee23945 to 9642406 on March 18, 2026 01:28
@yaoyu-33 added the labels area:model (Model implementations and HF bridge logic), area:training (Training loop, callbacks, and runtime integration), and feature (New capabilities, enhancements, or enablement work) on Mar 19, 2026
@yashaswikarnati force-pushed the mimo/phase4-training-rebuild branch from 0aea6d8 to 9642406 on March 25, 2026 06:50
yashaswikarnati and others added 3 commits March 25, 2026 09:47
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Li Ding <liding@nvidia.com>
Signed-off-by: Li Ding <liding@nvidia.com>
@liding-nv liding-nv requested review from a team, erhoo82 and malay-nagda as code owners March 25, 2026 22:03
@aroshanghias-nvd aroshanghias-nvd changed the base branch from mimo/upstream-base-for-rebuild to main March 25, 2026 22:20
@aroshanghias-nvd aroshanghias-nvd changed the base branch from main to mimo/upstream-base-for-rebuild March 25, 2026 22:21
@aroshanghias-nvd aroshanghias-nvd changed the base branch from mimo/upstream-base-for-rebuild to main March 25, 2026 22:21
Signed-off-by: Li Ding <liding@nvidia.com>
Signed-off-by: Li Ding <liding@nvidia.com>
@liding-nv
Contributor

/ok to test 5e0b771

@coderabbitai
Contributor

coderabbitai bot commented Mar 26, 2026

📝 Walkthrough

Walkthrough

This pull request introduces comprehensive MIMO heterogeneous parallel training infrastructure, including refactored data loading with MIMO-specific paths, new abstract dataset provider interfaces, loss mask tensor handling, updated distributed grid and process group management, and new training entrypoints with multi-module gradient synchronization support.

Changes

  • Submodule & Data Loading Foundation (3rdparty/Megatron-LM, src/megatron/bridge/data/loaders.py): Updated Megatron-LM pointer; refactored build_train_valid_test_data_loaders to add a direct MIMO path with rank synchronization via all_reduce(MAX); simplified the generic path by removing helper dispatch functions.
  • MIMO Dataset Provider Interface (src/megatron/bridge/data/mimo/base_provider.py, src/megatron/bridge/data/mimo/__init__.py): Added abstract MimoDatasetProvider class with required build_datasets() and get_collate_fn() methods; updated module exports to surface the new base class.
  • MIMO Collation & Dataset Processing (src/megatron/bridge/data/mimo/collate.py, src/megatron/bridge/data/mimo/dataset.py): Extended mimo_collate_fn to stack loss_mask tensors; modified MimoDataset.__getitem__ to compute next-token labels with shifted positions, compute loss masks excluding padding and modality tokens, and return masked label/loss_mask pairs.
  • MIMO Data Provider Implementations (src/megatron/bridge/data/mimo/hf_provider.py, src/megatron/bridge/data/mimo/mock_provider.py): Updated both providers to inherit from MimoDatasetProvider; added trust_remote_code, hf_data_files, and preprocess_fn configuration fields to HFMimoDatasetProvider; added a tokenizer validation guard in MockMimoProvider.
  • MIMO Data Utilities & Loaders (src/megatron/bridge/data/mimo/dp_utils.py, src/megatron/bridge/data/mimo/loaders.py): Replaced the MimoDpInfo dataclass with a tuple return (dp_rank, dp_size, needs_data, loader_module); updated the get_mimo_dp_info signature to require a mimo_cfg parameter and use the MIMO_LANGUAGE_MODULE_KEY constant; modified build_mimo_data_loaders to require pre-built _grids and raise on missing infrastructure.
  • Model Grid & Topology Management (src/megatron/bridge/models/mimo/mimo_builder.py, src/megatron/bridge/models/mimo/mimo_config.py): Removed _default_topology; renamed create_embedding_and_position_groups to populate_embedding_and_position_groups; added is_pp_first_stage/is_pp_last_stage helpers; replaced "llm" hardcoding with MIMO_LANGUAGE_MODULE_KEY; added composite process groups ["tp", "pp"], ["tp", "ep", "pp"], ["dp", "ep"]; added _validate_parallelism_constraints() for TP/DP cross-module compatibility; changed the finalize() signature to require a non-optional world_size.
  • MIMO Model Provider & DDP (src/megatron/bridge/models/mimo/mimo_provider.py, src/megatron/bridge/models/mimo/mimo_ddp.py): Refactored from cached to non-cached infra via build_infra() storing _grids; added topology and module_output_ndim overrides; replaced mixed_precision_wrapper with direct dtype casting; changed initialize_model_parallel() to raise NotImplementedError; updated DDP wrapping to use MIMO_LANGUAGE_MODULE_KEY and handle missing grid entries; added a module_output_ndim field to MimoModelInfra.
  • MIMO Training Entrypoints & Step Functions (src/megatron/bridge/training/mimo_step.py, src/megatron/bridge/training/pretrain_mimo.py, src/megatron/bridge/training/train_mimo.py): Added new training modules: mimo_step.py with loss_func, get_batch, and forward_step for pipeline schedules with loss mask support; pretrain_mimo.py with setup_mimo() initialization and a pretrain_mimo() driver; train_mimo.py with train_step_mimo() and a train_mimo() main loop using MultiModulePipelineCommunicator and per-module scheduler advancement.
  • MIMO Training Utilities (src/megatron/bridge/training/mimo_parallel_utils.py): Added comprehensive utilities: unwrap_mimo_model(), is_current_rank_in_grid(), get_module_to_grid_tuple(); build_pg_collection_for_schedule() for schedule compatibility; multimodule_no_sync(), finalize_model_grads_multimodule(), and zero_grad_buffer_for_multimodule() for gradient management; validators validate_no_stub_ranks() and validate_data_loader_contract().
  • General Training Utilities (src/megatron/bridge/training/utils/train_utils.py): Updated training_log() to accept an optional pg_collection parameter; made MoE/MTP configuration handling defensive via getattr() with defaults; guarded FLOP computation on model attribute presence.
  • MIMO Model Provider Tests (tests/unit_tests/models/mimo/test_llava_provider.py, tests/unit_tests/models/mimo/test_mimo_builder.py, tests/unit_tests/models/mimo/test_mimo_provider.py): Updated test module key references from "llm" to "language"; removed _default_topology test coverage; added is_pp_first_stage/is_pp_last_stage tests; replaced the get_or_build_infra() caching test with an initialize_model_parallel() exception test; added composite process group assertions.
  • MIMO Data & DDP Tests (tests/unit_tests/data/mimo/test_*.py, tests/unit_tests/models/mimo/test_mimo_ddp.py): Updated data provider/loader tests to use MimoParallelismConfig in get_mimo_dp_info calls; changed return type assertions from dataclass fields to tuple values; replaced the "llm" module key with "language"; added a test for the missing-_grids ValueError; added a test for modality submodule skipping in DDP wrapping; updated collation and dataset tests for loss_mask handling.
  • MIMO Training Tests (tests/unit_tests/training/mimo/test_*.py): Added new test modules: test_mimo_parallel_utils.py validating rank participation, stub rank detection, contract validation, no_sync context, and grad buffer zeroing; test_mimo_step.py covering loss/batch/forward_step logic; test_pretrain_mimo.py validating seed initialization and pretrain wiring; test_mimo_config.py updated with the "language" key; added __init__.py.
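
The label-shift and collation behavior summarized above can be sketched as follows. PAD_ID and the function names are illustrative; the real MimoDataset also excludes modality tokens from the loss mask, which this sketch omits.

```python
import torch

PAD_ID = 0  # illustrative pad token id; not the project's actual value

def make_labels_and_mask(input_ids: torch.Tensor):
    """Next-token labels via a left shift; mask out padding positions.

    Minimal sketch of the behavior described for MimoDataset.__getitem__.
    """
    labels = torch.roll(input_ids, shifts=-1, dims=-1)
    labels[..., -1] = PAD_ID                 # no target for the final position
    loss_mask = (labels != PAD_ID).float()   # exclude padding from the loss
    return labels, loss_mask

def collate(samples: list[dict]) -> dict:
    """Stack per-sample tensors into batch tensors, loss_mask included,
    mirroring what mimo_collate_fn is described to do."""
    return {key: torch.stack([s[key] for s in samples]) for key in samples[0]}
```

For input_ids [5, 6, 7, 0] this yields labels [6, 7, 0, 0] and loss_mask [1, 1, 0, 0].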

Sequence Diagrams

sequenceDiagram
    participant Client as Training Client
    participant DataLoader as MIMO DataLoader
    participant Provider as MimoDatasetProvider
    participant Dataset as MimoDataset
    participant Collate as mimo_collate_fn
    participant Trainer as Trainer

    Client->>DataLoader: build_mimo_data_loaders()
    DataLoader->>Provider: build_datasets(context)
    Provider-->>DataLoader: (train_ds, valid_ds, test_ds)
    DataLoader->>Collate: (batches from DataLoader)
    Collate->>Dataset: Stack loss_mask, input_ids, labels, attention_mask
    Collate-->>DataLoader: Collated batch dict
    DataLoader-->>Client: (train_loader, valid_loader, test_loader)
    
    Client->>Trainer: train_mimo(model, optimizer, iterators)
    Trainer->>Trainer: train_step_mimo()
    Trainer->>Trainer: multimodule_no_sync() context
    Trainer->>Trainer: forward_backward_pipelining()
    Trainer->>Trainer: finalize_model_grads_multimodule()
    Trainer->>Trainer: optimizer.step()
    Trainer-->>Client: loss_dict, grad_norm
sequenceDiagram
    participant Rank as Distributed Rank
    participant MIMO as MimoModel
    participant ForwardStep as forward_step()
    participant LossFn as loss_func()
    participant DataIter as DataIterator
    participant Alloc as Memory/Alloc

    Rank->>ForwardStep: (state, data_iter, model)
    ForwardStep->>MIMO: unwrap_mimo_model()
    MIMO-->>ForwardStep: model ref
    ForwardStep->>DataIter: Check if rank needs data
    alt Rank needs data
        ForwardStep->>DataIter: get_batch()
        DataIter-->>ForwardStep: {loss_mask, input_ids, labels, ...}
    else Rank doesn't need data
        ForwardStep->>Alloc: Create null placeholders
        Alloc-->>ForwardStep: {input_ids: None, ...}
    end
    ForwardStep->>MIMO: model(**batch)
    MIMO-->>ForwardStep: output_tensor
    alt Last PP stage
        ForwardStep->>LossFn: Return partial(loss_func, loss_mask)
    else Intermediate stage
        ForwardStep-->>Rank: (output_tensor, None)
    end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes


Suggested labels

ready-to-merge

Suggested reviewers

  • erhoo82
  • malay-nagda
  • yaoyu-33
🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 inconclusive)

  • Test Results For Major Changes — ❓ Inconclusive. Explanation: the PR mentions 122 unit tests passed but provides no formal test results documentation, convergence validation, or performance benchmarking; integration/end-to-end tests are deferred to Phase 5. Resolution: include a formal test results report, convergence metrics for training changes, performance benchmarks if applicable, and document the testing scope with Phase 5 limitations.
✅ Passed checks (3 passed)
  • Description Check — ✅ Passed. Check skipped: CodeRabbit's high-level summary is enabled.
  • Title check — ✅ Passed. The title 'feat(mimo): Phase 4 - MiMo training, model/provider, data loading, heterogeneous parallelism' directly and comprehensively summarizes the main change: adding Phase 4 of MiMo (Multi-Input Multi-Output) training support with model/provider infrastructure, data loading, and heterogeneous parallelism support.
  • Docstring Coverage — ✅ Passed. Docstring coverage is 89.51%, above the required threshold of 80.00%.



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 8

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tests/unit_tests/models/mimo/test_mimo_builder.py (1)

122-125: ⚠️ Potential issue | 🟡 Minor

Stale comment: "llm" should be "language".

The comment still references "llm" but the module key was changed to "language".

📝 Suggested fix
-        # First call (llm): shape [8, 1, 1, 2, 1]
+        # First call (language): shape [8, 1, 1, 2, 1]
         first_call_kwargs = mock_grid_class.call_args_list[0][1]
         assert first_call_kwargs["shape"] == [8, 1, 1, 2, 1]
         assert first_call_kwargs["rank_offset"] == 0
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit_tests/models/mimo/test_mimo_builder.py` around lines 122 - 125,
Update the stale inline comment that reads "# First call (llm): shape [8, 1, 1,
2, 1]" to reference the new module key name "language" instead of "llm"; locate
the comment near the assertions using first_call_kwargs and
mock_grid_class.call_args_list in test_mimo_builder.py and change it to
something like "# First call (language): shape [8, 1, 1, 2, 1]" so the comment
matches the updated module key.
🧹 Nitpick comments (8)
tests/unit_tests/data/mimo/test_dp_utils.py (1)

78-81: Consider underscore prefix for unused unpacked variables.

In tests that don't assert on dp_rank and dp_size, consider using underscores to silence Ruff warnings and signal intent.

♻️ Suggested fix (optional)
-    dp_rank, dp_size, needs_data, loader_module = get_mimo_dp_info(mimo_cfg, grids)
+    _dp_rank, _dp_size, needs_data, loader_module = get_mimo_dp_info(mimo_cfg, grids)

Apply similar changes at lines 94, 110, and 126 where dp_rank/dp_size are not asserted.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit_tests/data/mimo/test_dp_utils.py` around lines 78 - 81, The test
unpacks get_mimo_dp_info into dp_rank and dp_size but never uses them, causing
lint warnings; change the unused variables to _dp_rank and _dp_size (or simply _
and _dp_size) in the test(s) that don't assert them (the unpack in
tests/unit_tests/data/mimo/test_dp_utils.py around the get_mimo_dp_info calls at
the locations mentioned) so the intent is clear and Ruff warnings are silenced
while keeping loader_module and needs_data assertions unchanged.
src/megatron/bridge/data/mimo/loaders.py (1)

83-83: Unused variable: prefix loader_module with underscore.

The loader_module value is unpacked but never used. Per static analysis hint (RUF059), prefix it with an underscore.

♻️ Suggested fix
-    dp_rank, dp_size, needs_data, loader_module = get_mimo_dp_info(cfg.model.mimo_parallelism_config, grids)
+    dp_rank, dp_size, needs_data, _loader_module = get_mimo_dp_info(cfg.model.mimo_parallelism_config, grids)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/megatron/bridge/data/mimo/loaders.py` at line 83, The unpacked variable
loader_module from the call to
get_mimo_dp_info(cfg.model.mimo_parallelism_config, grids) is unused; rename it
to _loader_module (or prefix with an underscore) to satisfy the unused-variable
lint rule (RUF059). Update the assignment dp_rank, dp_size, needs_data,
loader_module to dp_rank, dp_size, needs_data, _loader_module wherever that
pattern appears (e.g., in the function or module that calls get_mimo_dp_info) so
the unused value is clearly marked.
src/megatron/bridge/data/mimo/hf_provider.py (1)

8-8: Prefer modern union syntax over Union and Optional.

Per coding guidelines, use X | Y for union types instead of Union[X, Y], and T | None instead of Optional[T].

♻️ Suggested fix
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Tuple

Then update line 77:

-    hf_data_files: Optional[Union[str, List[str]]] = None
+    hf_data_files: str | List[str] | None = None

As per coding guidelines: "Use X | Y for union types instead of Union[X, Y]" and "Use T | None for nullable types instead of Optional[T]".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/megatron/bridge/data/mimo/hf_provider.py` at line 8, The file imports
Optional and Union from typing; per guidelines remove Optional and Union from
the import list and replace their uses with modern PEP 604 syntax (T | None for
optional types and A | B for unions). Update the import line in hf_provider.py
to drop Optional and Union and then replace every type annotation in this module
(including the annotation referenced around line 77) that uses Union[...] or
Optional[...] to use the pipe syntax instead so function/method signatures and
return types (wherever present in this module) use X | Y and T | None.
tests/unit_tests/training/mimo/test_mimo_step.py (2)

30-52: Prefix unused metrics variable with underscore to satisfy linter.

The metrics variable is unpacked but not used in these tests. Use _ prefix to indicate intentional non-use.

♻️ Suggested fix
-        total_loss, num_tokens, metrics = loss_func(loss_mask, output_tensor)
+        total_loss, num_tokens, _metrics = loss_func(loss_mask, output_tensor)

Apply similarly to line 49.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit_tests/training/mimo/test_mimo_step.py` around lines 30 - 52, In
the two tests test_loss_with_all_ones_mask and test_loss_with_all_zeros_mask,
the unpacked third return value from loss_func is stored in an unused variable
named metrics; rename it to _metrics (or simply _ ) to satisfy the linter and
indicate intentional non-use—update the unpacking in both tests where
loss_func(loss_mask, output_tensor) is called so it reads e.g. total_loss,
num_tokens, _metrics = loss_func(...)

111-111: Prefix unused output variable with underscore.

Per static analysis, the output variable is unpacked but unused.

♻️ Suggested fix
-        output, loss_fn = forward_step(mock_state, data_iter, mock_model)
+        _output, loss_fn = forward_step(mock_state, data_iter, mock_model)

Apply similarly to line 138.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit_tests/training/mimo/test_mimo_step.py` at line 111, The unpacked
variable "output" from the call to forward_step(mock_state, data_iter,
mock_model) is unused; rename it to "_output" to satisfy static analysis and
avoid unused-variable warnings, and make the same change for the other
forward_step unpack at the other location where "output" is unused (apply the
same "_output" prefix there); ensure you only rename the local variable, not the
call or other identifiers.
tests/unit_tests/training/mimo/test_pretrain_mimo.py (1)

88-88: Inconsistent module key: "llm" vs "language".

This test uses "llm" as the module key in module_to_grid_map, but the rest of this PR consistently uses "language" (which matches MIMO_LANGUAGE_MODULE_KEY). While the seed setting logic iterates over all keys and doesn't require a specific name, using "language" would be more consistent.

♻️ Suggested fix
-    mimo_infra = SimpleNamespace(module_to_grid_map={"llm": grid})
+    mimo_infra = SimpleNamespace(module_to_grid_map={"language": grid})
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit_tests/training/mimo/test_pretrain_mimo.py` at line 88, Replace the
inconsistent module key "llm" in the test's mimo_infra.module_to_grid_map with
the canonical key used throughout the PR ("language" / MIMO_LANGUAGE_MODULE_KEY)
so the test matches the rest of the codebase; update the SimpleNamespace
initialization (mimo_infra = SimpleNamespace(module_to_grid_map={"llm": grid}))
to use "language" as the map key to align with the seed-setting logic and naming
elsewhere.
src/megatron/bridge/data/mimo/base_provider.py (1)

6-8: Use Python 3.10+ type hint syntax.

Per coding guidelines, use T | None instead of Optional[T] and built-in tuple instead of typing.Tuple for Python 3.10+.

♻️ Suggested fix
-from typing import Callable, Optional, Tuple
+from typing import Callable

And update the return type on line 41:

-    ) -> Tuple[Optional[Dataset], Optional[Dataset], Optional[Dataset]]:
+    ) -> tuple[Dataset | None, Dataset | None, Dataset | None]:
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/megatron/bridge/data/mimo/base_provider.py` around lines 6 - 8, Replace
typing.Optional and typing.Tuple usages with Python 3.10+ syntax: remove
Optional and Tuple from imports and change annotations like Optional[T] to T |
None and Tuple[A, B] to tuple[A, B]; update any Callable[...] imports/uses to
remain as-is if needed. In particular, update the return type annotation
referenced on line 41 to use the new native syntax (e.g., change Optional[X] to
X | None or Tuple[...] to tuple[...]) and ensure imports only include
abstractmethod and dataclass (and Callable if still used).
src/megatron/bridge/data/mimo/dp_utils.py (1)

6-6: Use built-in tuple for type hints per Python 3.10+ guidelines.

♻️ Suggested fix
-from typing import TYPE_CHECKING, Dict, Tuple
+from typing import TYPE_CHECKING

And update the return type annotation on line 21:

-) -> Tuple[int, int, bool, str]:
+) -> tuple[int, int, bool, str]:

Note: Dict import can also be removed if the type hint on line 20 uses the built-in dict instead.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/megatron/bridge/data/mimo/dp_utils.py` at line 6, Replace typing imports
Dict and Tuple with built-in types: remove Dict and Tuple from the import line
(keep TYPE_CHECKING) and update all type hints that use Dict[...] and Tuple[...]
to use built-in dict[...] and tuple[...] instead; specifically change the return
type annotation on the function currently annotated with Tuple/Dict to use tuple
and dict, and remove the now-unused Dict/Tuple imports in
src/megatron/bridge/data/mimo/dp_utils.py.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/megatron/bridge/data/loaders.py`:
- Around line 205-207: The do_valid and do_test flags are using
cfg.train.eval_iters but the module reads evaluation settings from
cfg.validation; update the checks in the block that sets
do_train/do_valid/do_test so that do_valid and do_test use
cfg.validation.eval_iters (or the appropriate cfg.validation field) instead of
cfg.train.eval_iters, keeping do_train unchanged; adjust references to the
variables do_valid and do_test in the surrounding code if needed to ensure
evaluation behavior follows cfg.validation.eval_iters.

In `@src/megatron/bridge/data/mimo/collate.py`:
- Around line 62-63: The docstring for mimo_collate_fn is missing documentation
for the newly collated "loss_mask" field; update the mimo_collate_fn docstring
(Args, Returns and example sections) to list loss_mask as a required tensor in
the collated payload, describe its shape and purpose (e.g., mask for loss
computation), and show it in the example return structure alongside existing
keys so external callers see the API contract.

In `@src/megatron/bridge/models/mimo/mimo_provider.py`:
- Around line 412-413: setup_mimo() already constructs a MimoModelInfra but
build_infra() is being called again here and later in provide(), causing
multiple, inconsistent infra/communicator/DP wrappers; change provide() to
accept a MimoModelInfra (e.g., provide(self, infra: MimoModelInfra, ...)) or add
a private helper (e.g., _provide_with_infra(self, infra, ...)) that constructs
the model from the given infra instead of calling build_infra(); then thread the
existing infra from setup_mimo() into that call and remove the extra
build_infra() invocation so the same infra/communicator/_grids are used across
construction and DDP wrapping.
- Line 201: The call to populate_embedding_and_position_groups(pp_group) must be
executed on all ranks in the same global order because it invokes
torch.distributed.new_group; move the call out of the module participation
conditional so every process calls
populate_embedding_and_position_groups(pp_group) in the same sequence, then only
assign its returned handles (pos_embd_pg, embd_pg) to the module-local variables
for ranks that actually participate in the module; keep the pp_group argument
and preserve use of the returned handles but ensure non-participating ranks
discard or set them to None while still calling the helper to maintain
collective ordering.
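
The collective-ordering constraint described above can be sketched as follows. The helper name and the new_group_fn injection are illustrative; new_group_fn stands in for torch.distributed.new_group so the pattern can be demonstrated without an initialized process group.

```python
def populate_groups_everywhere(member_ranks, my_rank, new_group_fn=None):
    """Call the (collective) group constructor on EVERY rank, in the same
    global order, and keep the handle only on member ranks.

    torch.distributed.new_group must be invoked by all processes in an
    identical sequence, even by ranks that are not members of the group
    being created; only the members use the returned handle.
    """
    if new_group_fn is None:  # real usage: fall back to the torch collective
        import torch.distributed as dist
        new_group_fn = dist.new_group
    group = new_group_fn(ranks=member_ranks)  # collective: all ranks reach this
    return group if my_rank in member_ranks else None
```

Non-member ranks still make the call (preserving global ordering) but discard the result, which is the shape of the fix the comment asks for.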

In `@src/megatron/bridge/training/mimo_parallel_utils.py`:
- Around line 173-209: The wrapper finalize_model_grads_multimodule currently
discards the schedule-provided force_all_reduce; update the call so the flag is
forwarded into _finalize_model_grads (i.e., call _finalize_model_grads([module],
num_tokens=num_tokens, pg_collection=module_pg,
force_all_reduce=force_all_reduce)) so schedule-controlled behavior is
preserved; locate this change inside finalize_model_grads_multimodule where
module_pg is not None and ensure the symbol force_all_reduce is passed through
unchanged from the wrapper signature that includes infra and
module_to_grid_tuple.

In `@src/megatron/bridge/training/pretrain_mimo.py`:
- Around line 58-65: The seed assignment treats 0 as unset because it relies on
truthiness; change the logic in the seed initialization so you check for None
explicitly: prefer cfg.seed when cfg has attribute 'seed' even if it's 0,
otherwise fall back to cfg.rng.seed if that attribute exists and is not None,
and finally default to 1234; update the expression that defines the variable
seed (the current getattr chain) to use explicit "is not None" checks or a small
helper that returns the first non-None of cfg.seed, getattr(cfg, 'rng',
None).seed, and 1234 so reproducible seeds like 0 are preserved.
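
The first-non-None seed resolution suggested above can be sketched like this; the attribute names (cfg.seed, cfg.rng.seed) follow the review comment and are assumptions about the config shape.

```python
from types import SimpleNamespace  # used below to mock a config object

def resolve_seed(cfg, default=1234):
    """Return the first explicitly-set seed, treating 0 as a valid seed.

    Explicit `is not None` checks avoid the truthiness bug where seed=0
    would fall through to the default.
    """
    seed = getattr(cfg, "seed", None)
    if seed is not None:
        return seed
    rng_seed = getattr(getattr(cfg, "rng", None), "seed", None)
    if rng_seed is not None:
        return rng_seed
    return default
```

With this helper, resolve_seed(SimpleNamespace(seed=0)) returns 0 rather than 1234.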

In `@src/megatron/bridge/training/train_mimo.py`:
- Around line 164-185: The broadcast currently only updates loss_dict on the
logging rank when source_rank != last_rank and skips broadcast when source_rank
== last_rank, leaving other ranks with empty dicts; change the logic to always
call torch.distributed.broadcast_object_list whenever source_rank >= 0 (use
src=source_rank) so every rank receives the same object, then on every rank set
loss_dict = received_obj or {} and move any tensor values to the local CUDA
device (reuse the existing received->loss_dict assignment and tensor .cuda()
conversion logic); update references: last_rank, my_rank, source_rank,
loss_dict, and broadcast_object_list in train_mimo.py so training_log sees
identical loss keys on all ranks.
- Around line 397-402: The condition uses train_config.eval_interval but the
project standard is cfg.validation.*; update the evaluation gating in the MIMO
loop to read eval_interval from the validation config (e.g., use validation_cfg
= cfg.validation or cfg.validation.eval_interval) instead of
train_config.eval_interval, and change the check to use
validation_cfg.eval_interval (ensure the variable is in scope) in the clause
that currently references train_state.step % train_config.eval_interval and the
None check, leaving train_state.step and valid_data_iterator as-is.
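
The always-broadcast fix described in the first train_mimo.py comment can be sketched as follows. sync_loss_dict and broadcast_fn are illustrative stand-ins, with broadcast_fn playing the role of torch.distributed.broadcast_object_list; the real code would also move tensor values to the local CUDA device.

```python
def sync_loss_dict(loss_dict, my_rank, source_rank, broadcast_fn):
    """Every rank participates in the broadcast whenever a source exists,
    so all ranks end up with identical loss keys for training_log.

    broadcast_object_list mutates the passed list in place on receivers,
    which is why a one-element list is used as the payload.
    """
    if source_rank < 0:          # no rank produced a loss this step
        return loss_dict or {}
    payload = [loss_dict if my_rank == source_rank else None]
    broadcast_fn(payload, src=source_rank)  # collective: called on ALL ranks
    return payload[0] or {}
```

The key change from the reported bug: the broadcast is unconditional for every rank once a source exists, instead of being skipped when source_rank happens to equal the logging rank.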

to satisfy the linter and indicate intentional non-use—update the unpacking in
both tests where loss_func(loss_mask, output_tensor) is called so it reads e.g.
total_loss, num_tokens, _metrics = loss_func(...)
- Line 111: The unpacked variable "output" from the call to
forward_step(mock_state, data_iter, mock_model) is unused; rename it to
"_output" to satisfy static analysis and avoid unused-variable warnings, and
make the same change for the other forward_step unpack at the other location
where "output" is unused (apply the same "_output" prefix there); ensure you
only rename the local variable, not the call or other identifiers.

In `@tests/unit_tests/training/mimo/test_pretrain_mimo.py`:
- Line 88: Replace the inconsistent module key "llm" in the test's
mimo_infra.module_to_grid_map with the canonical key used throughout the PR
("language" / MIMO_LANGUAGE_MODULE_KEY) so the test matches the rest of the
codebase; update the SimpleNamespace initialization (mimo_infra =
SimpleNamespace(module_to_grid_map={"llm": grid})) to use "language" as the map
key to align with the seed-setting logic and naming elsewhere.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: c2643fa1-0ac5-41ab-a44a-06d77567647a

📥 Commits

Reviewing files that changed from the base of the PR and between 3e2f5fa and 5e0b771.

📒 Files selected for processing (34)
  • 3rdparty/Megatron-LM
  • src/megatron/bridge/data/loaders.py
  • src/megatron/bridge/data/mimo/__init__.py
  • src/megatron/bridge/data/mimo/base_provider.py
  • src/megatron/bridge/data/mimo/collate.py
  • src/megatron/bridge/data/mimo/dataset.py
  • src/megatron/bridge/data/mimo/dp_utils.py
  • src/megatron/bridge/data/mimo/hf_provider.py
  • src/megatron/bridge/data/mimo/loaders.py
  • src/megatron/bridge/data/mimo/mock_provider.py
  • src/megatron/bridge/models/mimo/llava_provider.py
  • src/megatron/bridge/models/mimo/mimo_builder.py
  • src/megatron/bridge/models/mimo/mimo_config.py
  • src/megatron/bridge/models/mimo/mimo_ddp.py
  • src/megatron/bridge/models/mimo/mimo_provider.py
  • src/megatron/bridge/training/mimo_parallel_utils.py
  • src/megatron/bridge/training/mimo_step.py
  • src/megatron/bridge/training/pretrain_mimo.py
  • src/megatron/bridge/training/train_mimo.py
  • src/megatron/bridge/training/utils/train_utils.py
  • tests/unit_tests/data/mimo/test_collate.py
  • tests/unit_tests/data/mimo/test_dataset.py
  • tests/unit_tests/data/mimo/test_dp_utils.py
  • tests/unit_tests/data/mimo/test_hf_provider.py
  • tests/unit_tests/data/mimo/test_loaders.py
  • tests/unit_tests/models/mimo/test_llava_provider.py
  • tests/unit_tests/models/mimo/test_mimo_builder.py
  • tests/unit_tests/models/mimo/test_mimo_ddp.py
  • tests/unit_tests/models/mimo/test_mimo_provider.py
  • tests/unit_tests/training/mimo/__init__.py
  • tests/unit_tests/training/mimo/test_mimo_config.py
  • tests/unit_tests/training/mimo/test_mimo_parallel_utils.py
  • tests/unit_tests/training/mimo/test_mimo_step.py
  • tests/unit_tests/training/mimo/test_pretrain_mimo.py

Comment on lines +205 to +207
do_train = train_dataloader is not None and cfg.train.train_iters > 0
do_valid = valid_dataloader is not None and cfg.train.eval_iters > 0
do_test = test_dataloader is not None and cfg.train.eval_iters > 0

⚠️ Potential issue | 🟠 Major

Use the validation config here.

The rest of this module reads evaluation settings from cfg.validation; using cfg.train.eval_iters on the MIMO path can raise or permanently disable do_valid/do_test.

Suggested fix
-        do_valid = valid_dataloader is not None and cfg.train.eval_iters > 0
-        do_test = test_dataloader is not None and cfg.train.eval_iters > 0
+        do_valid = valid_dataloader is not None and cfg.validation.eval_iters > 0
+        do_test = test_dataloader is not None and cfg.validation.eval_iters > 0
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/megatron/bridge/data/loaders.py` around lines 205 - 207, The do_valid and
do_test flags are using cfg.train.eval_iters but the module reads evaluation
settings from cfg.validation; update the checks in the block that sets
do_train/do_valid/do_test so that do_valid and do_test use
cfg.validation.eval_iters (or the appropriate cfg.validation field) instead of
cfg.train.eval_iters, keeping do_train unchanged; adjust references to the
variables do_valid and do_test in the surrounding code if needed to ensure
evaluation behavior follows cfg.validation.eval_iters.
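The config-shape mismatch behind this finding can be shown without any Megatron code. The sketch below uses hypothetical `SimpleNamespace` configs that mirror the split described in the comment (`eval_iters` lives under `cfg.validation`, not `cfg.train`):

```python
from types import SimpleNamespace

# Hypothetical config shapes mirroring the review comment: eval_iters lives
# under cfg.validation, so gating on cfg.train.eval_iters either raises
# AttributeError or (if the field is absent/zero) permanently disables eval.
cfg = SimpleNamespace(
    train=SimpleNamespace(train_iters=100),
    validation=SimpleNamespace(eval_iters=10),
)

valid_dataloader = object()  # stand-in for a real dataloader

# Gating as suggested in the review: read from the validation config.
do_valid = valid_dataloader is not None and cfg.validation.eval_iters > 0
print(do_valid)  # True

# Reading the same field from cfg.train fails on configs shaped like this one.
try:
    _ = cfg.train.eval_iters
except AttributeError:
    print("cfg.train has no eval_iters")
```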

Comment on lines +62 to 63
loss_mask = torch.stack([item["loss_mask"] for item in batch])


⚠️ Potential issue | 🟡 Minor

Update mimo_collate_fn docstring to include loss_mask.

Line 62 and Line 114 make loss_mask part of the required collated payload, but the Args/Returns/example contract still omits it.

📝 Suggested doc update
     Args:
         batch: List of examples from MimoDataset, each containing:
             - input_ids: Token IDs with placeholder tokens
             - labels: Labels for causal LM training
+            - loss_mask: Per-token loss mask
             - attention_mask: Attention mask
             - position_ids: Position indices
             - modality_inputs: Dict[str, Dict[str, Any]] with preprocessed inputs
@@
     Returns:
         Dict containing:
             - input_ids: (batch, seq) stacked token IDs
             - labels: (batch, seq) stacked labels
+            - loss_mask: (batch, seq) stacked per-token loss mask
             - attention_mask: (batch, seq) attention mask
             - position_ids: (batch, seq) position indices
             - modality_inputs: Dict[str, Dict[str, Tensor]] with batched modality tensors

As per coding guidelines For interfaces that may be used outside a file, prefer docstrings over comments.

Also applies to: 114-114

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/megatron/bridge/data/mimo/collate.py` around lines 62 - 63, The docstring
for mimo_collate_fn is missing documentation for the newly collated "loss_mask"
field; update the mimo_collate_fn docstring (Args, Returns and example sections)
to list loss_mask as a required tensor in the collated payload, describe its
shape and purpose (e.g., mask for loss computation), and show it in the example
return structure alongside existing keys so external callers see the API
contract.


# Create embedding groups for PP > 1 (collective operation on all PP ranks)
-pos_embd_pg, embd_pg = create_embedding_and_position_groups(pp_group)
+pos_embd_pg, embd_pg = populate_embedding_and_position_groups(pp_group)

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🌐 Web query:

Does torch.distributed.new_group need to be called by every rank in the default process group, including non-members, and in the same global order?

💡 Result:

Yes, torch.distributed.new_group must be called by every rank in the default process group (the world group), including non-members of the new subgroup. This is required for synchronization via an internal store-based barrier (_store_based_barrier) that ensures consistent global metadata across all ranks about the created process groups. Non-members receive a special NON_GROUP_MEMBER handle (id -100) but still participate in the barrier. It must also be called in the same global order across all ranks to avoid deadlocks or hangs, especially when creating multiple overlapping groups. Divergent order among participants of individual groups can still cause issues in some backends like Gloo, even if participant-local order is consistent. With use_local_synchronization=True, non-participants can skip, but global order consistency is still advised to prevent backend-specific hangs.


🏁 Script executed:

# First, let's examine the relevant section of mimo_provider.py around line 201
head -n 230 src/megatron/bridge/models/mimo/mimo_provider.py | tail -n 50

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 2127


🏁 Script executed:

# Find the populate_embedding_and_position_groups function definition
rg "def populate_embedding_and_position_groups" -A 15

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1460


🏁 Script executed:

# Check the full context of _get_pg_collections_from_grids to understand the guard
rg "_get_pg_collections_from_grids" -A 40

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 15136


Call populate_embedding_and_position_groups() on all ranks in the same global order.

populate_embedding_and_position_groups() invokes torch.distributed.new_group(), which requires participation from all ranks in the default process group in the same global order, regardless of membership in the created subgroup. The current code guards this call inside the module participation check (line 201), skipping it on non-participating ranks. With multiple modules and disjoint rank sets, this breaks collective ordering as soon as any module has PP > 1.

Move the helper call outside the participation guard and only assign the returned process group handles to ranks that actually belong to the module.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/megatron/bridge/models/mimo/mimo_provider.py` at line 201, The call to
populate_embedding_and_position_groups(pp_group) must be executed on all ranks
in the same global order because it invokes torch.distributed.new_group; move
the call out of the module participation conditional so every process calls
populate_embedding_and_position_groups(pp_group) in the same sequence, then only
assign its returned handles (pos_embd_pg, embd_pg) to the module-local variables
for ranks that actually participate in the module; keep the pp_group argument
and preserve use of the returned handles but ensure non-participating ranks
discard or set them to None while still calling the helper to maintain
collective ordering.
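The control-flow shape of the suggested fix can be sketched in a single process with stub functions (no real `torch.distributed` calls; `populate_groups_stub`, `build_modules`, and the rank sets below are all hypothetical). The point is only the structure: the collective helper is called unconditionally in one global order, while the returned handle is assigned conditionally:

```python
# Records call order to show every rank would make the same sequence of
# collective calls, regardless of module membership.
created_order = []

def populate_groups_stub(module_name):
    """Stand-in for populate_embedding_and_position_groups (a collective call)."""
    created_order.append(module_name)
    return f"pg:{module_name}"

def build_modules(my_rank, module_to_ranks):
    handles = {}
    for name, ranks in module_to_ranks.items():
        # WRONG (pre-fix): calling the helper only inside a membership check.
        # RIGHT (post-fix): call unconditionally, assign conditionally.
        pg = populate_groups_stub(name)            # every rank, same global order
        handles[name] = pg if my_rank in ranks else None
    return handles

module_to_ranks = {"language": {0, 1}, "vision": {2, 3}}
h0 = build_modules(0, module_to_ranks)
print(h0)             # {'language': 'pg:language', 'vision': None}
print(created_order)  # ['language', 'vision'] — identical on all ranks
```

Non-member ranks still "participate" in the call and simply discard the handle, which is what keeps the collective ordering consistent.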

Comment on lines +412 to +413
# Build infrastructure
infra = self.build_infra()

⚠️ Potential issue | 🔴 Critical

Avoid rebuilding MIMO infra during distributed model construction.

setup_mimo() already builds one MimoModelInfra, this method builds another, and self.provide() builds a third. That can bind the communicator, model internals, DDP wrappers, and _grids consumers to different process-group objects for the same logical module topology.

Please thread the already-built infra into model construction, or split provide() into a private helper that accepts an MimoModelInfra instead of recreating it.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/megatron/bridge/models/mimo/mimo_provider.py` around lines 412 - 413,
setup_mimo() already constructs a MimoModelInfra but build_infra() is being
called again here and later in provide(), causing multiple, inconsistent
infra/communicator/DP wrappers; change provide() to accept a MimoModelInfra
(e.g., provide(self, infra: MimoModelInfra, ...)) or add a private helper (e.g.,
_provide_with_infra(self, infra, ...)) that constructs the model from the given
infra instead of calling build_infra(); then thread the existing infra from
setup_mimo() into that call and remove the extra build_infra() invocation so the
same infra/communicator/_grids are used across construction and DDP wrapping.
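The refactor the comment asks for is a standard "build once, thread through" pattern. A minimal sketch (class and method names hypothetical, standing in for the real provider and `MimoModelInfra`):

```python
class MimoProviderSketch:
    """Toy provider: counts build_infra() calls to show the infra is built once."""

    def __init__(self):
        self.builds = 0

    def build_infra(self):
        self.builds += 1
        return {"id": self.builds}  # stand-in for a MimoModelInfra

    def _provide_with_infra(self, infra):
        # Construct the model against a caller-supplied infra instead of
        # rebuilding one internally.
        return ("model", infra["id"])

    def setup_mimo(self):
        infra = self.build_infra()           # the only build_infra() call
        model = self._provide_with_infra(infra)
        return infra, model

p = MimoProviderSketch()
infra, model = p.setup_mimo()
print(p.builds)  # 1 — one infra shared by construction and DDP wrapping
```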

Comment on lines +173 to +209
def finalize_model_grads_multimodule(
model,
num_tokens=None,
pg_collection=None,
force_all_reduce=None,
*,
infra: MimoModelInfra,
module_to_grid_tuple: List[Tuple],
):
"""Finalize gradients for each module using infra.pg_collections.

IMPORTANT: Signature matches schedule's call pattern:
config.finalize_model_grads_func([model], num_tokens, pg_collection, force_all_reduce=flag)

The `infra` and `module_to_grid_tuple` parameters are pre-bound via partial().
We ignore the schedule-provided `pg_collection` and use per-module PGs.

Args:
model: Model list (passed by schedule, ignored - we use module_to_grid_tuple).
num_tokens: Token count for gradient scaling.
pg_collection: Schedule-provided PG (ignored - we use per-module PGs).
force_all_reduce: Schedule-provided flag (ignored - per-module PGs control sync).
infra: MimoModelInfra with per-module pg_collections (keyword-only, bound via partial).
module_to_grid_tuple: List of (module, grid) tuples (keyword-only, bound via partial).
"""
for module, grid in module_to_grid_tuple:
if module is not None and is_current_rank_in_grid(grid):
# Get the module's pg_collection from infra
# Find the module name by matching the grid
module_pg = None
for module_name, mod_grid in infra.module_to_grid_map.items():
if mod_grid is grid:
module_pg = infra.pg_collections.get(module_name)
break

if module_pg is not None:
_finalize_model_grads([module], num_tokens=num_tokens, pg_collection=module_pg)

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🌐 Web query:

What is the current signature of megatron.core.distributed.finalize_model_grads.finalize_model_grads, and does it accept a force_all_reduce argument?

💡 Result:

The current signature of megatron.core.distributed.finalize_model_grads.finalize_model_grads is: finalize_model_grads(model: List[torch.nn.Module], num_tokens: Optional[torch.Tensor] = None, pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None, force_all_reduce: Optional[bool] = False) Yes, it accepts a force_all_reduce argument, which is an optional boolean parameter with default value False. This function all-reduces model grads across DP replicas, layernorm grads for sequence parallelism, embedding grads across pipeline stages (if not tied), and scales gradients by num_tokens.


🏁 Script executed:

# Find where _finalize_model_grads is defined and how it's used
rg "_finalize_model_grads" --type py -B 2 -A 5

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1508


🏁 Script executed:

# Check the specific file to see if there are other calls or context
cat -n src/megatron/bridge/training/mimo_parallel_utils.py | head -220

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 9984


Forward force_all_reduce to _finalize_model_grads.

The wrapper accepts force_all_reduce from the schedule but silently discards it at line 209. Although the docstring claims this is intentional ("ignored - per-module PGs control sync"), the underlying _finalize_model_grads function accepts and processes this parameter. Dropping a schedule-controlled flag changes gradient-finalization behavior and may lead to unexpected behavior in MIMO paths. Forward the parameter to maintain consistency with the schedule's expectations.

Suggested fix
-                _finalize_model_grads([module], num_tokens=num_tokens, pg_collection=module_pg)
+                _finalize_model_grads(
+                    [module],
+                    num_tokens=num_tokens,
+                    pg_collection=module_pg,
+                    force_all_reduce=force_all_reduce,
+                )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/megatron/bridge/training/mimo_parallel_utils.py` around lines 173 - 209,
The wrapper finalize_model_grads_multimodule currently discards the
schedule-provided force_all_reduce; update the call so the flag is forwarded
into _finalize_model_grads (i.e., call _finalize_model_grads([module],
num_tokens=num_tokens, pg_collection=module_pg,
force_all_reduce=force_all_reduce)) so schedule-controlled behavior is
preserved; locate this change inside finalize_model_grads_multimodule where
module_pg is not None and ensure the symbol force_all_reduce is passed through
unchanged from the wrapper signature that includes infra and
module_to_grid_tuple.
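The pre-binding-plus-forwarding pattern at issue here is easy to demonstrate in isolation. In this sketch (`finalize_multimodule` and the recording `_finalize_model_grads` are stand-ins, not the real Megatron functions), `infra` and `module_to_grid_tuple` are pre-bound via `functools.partial` as keyword-only arguments, while the schedule-supplied `force_all_reduce` is forwarded rather than dropped:

```python
from functools import partial

calls = []  # records the force_all_reduce value actually received downstream

def _finalize_model_grads(model, num_tokens=None, pg_collection=None, force_all_reduce=False):
    calls.append(force_all_reduce)

def finalize_multimodule(model, num_tokens=None, pg_collection=None,
                         force_all_reduce=None, *, infra, module_to_grid_tuple):
    for module, _grid in module_to_grid_tuple:
        # Forward the schedule-controlled flag instead of silently dropping it.
        _finalize_model_grads([module], num_tokens=num_tokens,
                              pg_collection=infra, force_all_reduce=force_all_reduce)

# infra and module_to_grid_tuple are pre-bound; the schedule never sees them.
bound = partial(finalize_multimodule, infra={"pg": "stub"},
                module_to_grid_tuple=[("mod_a", "grid_a")])

# The schedule calls with its own flag; the wrapper now passes it through.
bound(["mod_a"], None, None, force_all_reduce=True)
print(calls)  # [True]
```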

Comment on lines +58 to +65
import random

import numpy as np
import torch
from megatron.core import tensor_parallel

seed = getattr(cfg, "seed", None) or getattr(getattr(cfg, "rng", None), "seed", None) or 1234


⚠️ Potential issue | 🟡 Minor

Don't treat seed 0 as unset.

This truthy chain silently replaces cfg.seed = 0 with cfg.rng.seed or 1234, which breaks reproducibility for a valid explicit seed.

Suggested fix
-    seed = getattr(cfg, "seed", None) or getattr(getattr(cfg, "rng", None), "seed", None) or 1234
+    seed = getattr(cfg, "seed", None)
+    if seed is None:
+        seed = getattr(getattr(cfg, "rng", None), "seed", None)
+    if seed is None:
+        seed = 1234
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/megatron/bridge/training/pretrain_mimo.py` around lines 58 - 65, The seed
assignment treats 0 as unset because it relies on truthiness; change the logic
in the seed initialization so you check for None explicitly: prefer cfg.seed
when cfg has attribute 'seed' even if it's 0, otherwise fall back to
cfg.rng.seed if that attribute exists and is not None, and finally default to
1234; update the expression that defines the variable seed (the current getattr
chain) to use explicit "is not None" checks or a small helper that returns the
first non-None of cfg.seed, getattr(cfg, 'rng', None).seed, and 1234 so
reproducible seeds like 0 are preserved.
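The pitfall is plain Python semantics: `0 or fallback` evaluates the fallback because `0` is falsy. A minimal side-by-side (helper names are illustrative):

```python
def seed_truthy(explicit_seed, rng_seed=None, default=1234):
    # The pattern flagged in the review: an `or` chain over possibly-zero values.
    return explicit_seed or rng_seed or default

def seed_explicit(explicit_seed, rng_seed=None, default=1234):
    # Explicit None checks treat 0 as a valid, deliberate seed.
    if explicit_seed is not None:
        return explicit_seed
    if rng_seed is not None:
        return rng_seed
    return default

print(seed_truthy(0))                   # 1234 — explicit seed 0 silently discarded
print(seed_explicit(0))                 # 0    — reproducible seed preserved
print(seed_explicit(None, rng_seed=7))  # 7    — fallback still works
```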

Comment on lines +164 to +185
# Broadcast loss_dict to all ranks (the last rank is the logging rank for
# W&B/TensorBoard). Use broadcast_object_list from the source rank so every
# rank ends up with the same dict — no fragile P2P or GPU-side pickle needed.
last_rank = dist.get_world_size() - 1
my_rank = dist.get_rank()

# All ranks agree on which rank holds the loss (pick highest rank with data).
has_loss = 1 if loss_dict else 0
source_tensor = torch.tensor([my_rank if has_loss else -1], dtype=torch.int32, device="cuda")
torch.distributed.all_reduce(source_tensor, op=torch.distributed.ReduceOp.MAX)
source_rank = int(source_tensor.item())

# Only broadcast if the source and logging rank differ and a valid source exists.
if source_rank >= 0 and source_rank != last_rank:
obj = [loss_dict if my_rank == source_rank else None]
torch.distributed.broadcast_object_list(obj, src=source_rank)
if my_rank == last_rank:
received = obj[0] or {}
# Tensors inside the received dict carry the source rank's CUDA device;
# move them to this rank's device so training_log arithmetic works.
loss_dict = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in received.items()}


⚠️ Potential issue | 🟠 Major

Broadcast the reduced loss dict to every rank.

After broadcast_object_list(), only last_rank rewrites loss_dict; every other non-source rank still keeps {}. If source_rank == last_rank, the broadcast is skipped entirely. That leaves training_log() seeing different loss keys across ranks.

Suggested fix
-    # Only broadcast if the source and logging rank differ and a valid source exists.
-    if source_rank >= 0 and source_rank != last_rank:
+    if source_rank >= 0:
         obj = [loss_dict if my_rank == source_rank else None]
         torch.distributed.broadcast_object_list(obj, src=source_rank)
-        if my_rank == last_rank:
-            received = obj[0] or {}
-            # Tensors inside the received dict carry the source rank's CUDA device;
-            # move them to this rank's device so training_log arithmetic works.
-            loss_dict = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in received.items()}
+        received = obj[0] or {}
+        loss_dict = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in received.items()}
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/megatron/bridge/training/train_mimo.py` around lines 164 - 185, The
broadcast currently only updates loss_dict on the logging rank when source_rank
!= last_rank and skips broadcast when source_rank == last_rank, leaving other
ranks with empty dicts; change the logic to always call
torch.distributed.broadcast_object_list whenever source_rank >= 0 (use
src=source_rank) so every rank receives the same object, then on every rank set
loss_dict = received_obj or {} and move any tensor values to the local CUDA
device (reuse the existing received->loss_dict assignment and tensor .cuda()
conversion logic); update references: last_rank, my_rank, source_rank,
loss_dict, and broadcast_object_list in train_mimo.py so training_log sees
identical loss keys on all ranks.
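The source-rank election step in this code (contribute your rank if you hold a loss dict, else -1, then all-reduce with MAX) can be simulated in a single process to see why a valid source always wins. This is a sketch of the election logic only, not of `broadcast_object_list`:

```python
def elect_source_rank(per_rank_has_loss):
    """Mirror of all_reduce(ReduceOp.MAX) over each rank's contribution."""
    contributions = [
        rank if has_loss else -1
        for rank, has_loss in enumerate(per_rank_has_loss)
    ]
    return max(contributions)

print(elect_source_rank([False, True, True, False]))  # 2 — highest rank with data
print(elect_source_rank([False, False]))              # -1 — no rank has a loss dict
```

The review's point is about what happens after the election: once `source_rank >= 0`, every rank (not just the logging rank) must adopt the broadcast dict so `training_log()` sees identical keys everywhere.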

Comment on lines +397 to +402
# Evaluation at specified intervals
if (
train_config.eval_interval is not None
and train_state.step % train_config.eval_interval == 0
and valid_data_iterator is not None
):

⚠️ Potential issue | 🟠 Major

Read eval_interval from cfg.validation.

The rest of the setup/data path uses cfg.validation.*. Looking under cfg.train here means the MIMO loop can skip evaluation or fail on configs that don't mirror the field into train.

Suggested fix
-        if (
-            train_config.eval_interval is not None
-            and train_state.step % train_config.eval_interval == 0
-            and valid_data_iterator is not None
-        ):
+        if (
+            cfg.validation.eval_interval is not None
+            and train_state.step % cfg.validation.eval_interval == 0
+            and valid_data_iterator is not None
+        ):
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/megatron/bridge/training/train_mimo.py` around lines 397 - 402, The
condition uses train_config.eval_interval but the project standard is
cfg.validation.*; update the evaluation gating in the MIMO loop to read
eval_interval from the validation config (e.g., use validation_cfg =
cfg.validation or cfg.validation.eval_interval) instead of
train_config.eval_interval, and change the check to use
validation_cfg.eval_interval (ensure the variable is in scope) in the clause
that currently references train_state.step % train_config.eval_interval and the
None check, leaving train_state.step and valid_data_iterator as-is.


Labels

area:model Model implementations and HF bridge logic area:training Training loop, callbacks, and runtime integration feature New capabilities, enhancements, or enablement work
