
support for training qwen3 vl with dist train #2367

Open
shifangx wants to merge 8 commits into NVIDIA-NeMo:main from shifangx:shifang/qwen3_vl_dist_train

Conversation

Contributor

@shifangx shifangx commented Feb 13, 2026

What does this PR do ?

Add a one line overview of what this PR aims to accomplish.

This PR adds support for training Qwen3-VL with DistTrain.

Changelog

  • Add specific line by line info of high level changes in this PR.

GitHub Actions CI

See the CI section in the Contributing doc for how to trigger the CI. An NVIDIA developer will need to approve and trigger the CI for external contributors.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

If you haven't finished some of the above items, you can still open a "Draft" PR.

Additional Information

  • Related to # (issue)

Summary by CodeRabbit

Release Notes

  • New Features

    • Added distributed training support for vision-language models
    • Added decentralized process group support for model training
    • Added example script for pretraining vision-language models with distributed process groups
  • Updates

    • Enhanced vision-language model configuration with distributed parallelism options
    • Improved multi-modal training capabilities with advanced sequence packing
  • Bug Fixes

    • Fixed data loading behavior during distributed training

@shifangx shifangx requested review from a team, erhoo82 and malay-nagda as code owners February 13, 2026 03:04
@copy-pr-bot

copy-pr-bot bot commented Feb 13, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai
Contributor

coderabbitai bot commented Feb 13, 2026

📝 Walkthrough

This PR adds comprehensive Qwen3VL vision-language model support, including distributed training infrastructure, new vision model components, training step implementations with packed sequence handling, and updates to distributed initialization and process group management.

Changes

  • Submodule Update (3rdparty/Megatron-LM): updated git submodule pointer to a new commit hash.
  • Example Scripts & Configuration (examples/recipes/decentralized_pg/pretrain_qwen3_vl_simple.py, examples/recipes/qwen_vl/conf/qwen3_vl_pretrain_override_example.yaml, examples/recipes/qwen_vl/finetune_qwen_vl.py): added a decentralized process group pretraining example, expanded the YAML config with distributed settings and profiling options, and updated the forward step import to qwen3vl_step.
  • Performance & Script Tooling (scripts/performance/argument_parser.py, scripts/performance/run_script.py): added "qwen3vl" domain support to the CLI argument parser and forward step selection logic.
  • Training Infrastructure (src/megatron/bridge/training/initialize.py, src/megatron/bridge/training/setup.py, src/megatron/bridge/training/train.py, src/megatron/bridge/training/config.py): refactored distributed initialization to return multiple process groups and grids; added distributed training branching for vision/language module separation; updated CUDA graph and pipeline communicator handling; modified data parallel size calculations for distributed training mode.
  • Training Step & Utilities (src/megatron/bridge/training/qwen3vl_step.py, src/megatron/bridge/training/vlm_step.py, src/megatron/bridge/training/utils/packed_seq_utils.py, src/megatron/bridge/training/utils/train_utils.py): introduced a Qwen3VL-specific training step with multi-modal input handling and sequence packing; updated the VLM step for fixed padding logic; enhanced packed sequence utilities with pre/post-processing; added a rank-in-pg validation helper.
  • Vision-Language Model Architecture (src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/attention.py, src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/rope.py, src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/vision_model.py, src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/utils.py): added Qwen3VLSelfAttention with flash attention and inference support; introduced Qwen3VLMultimodalRotaryEmbedding with CP-aware RoPE; created Qwen3VLVisionModel with patch embedding and grid-based processing; added vision patch utilities, packing/unpacking, and context-parallel helpers.
  • Model Components & Configuration (src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py, src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/text_model.py, src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/transformer_block.py, src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/transformer_config.py): enhanced Qwen3VLModel with distributed training, vision module initialization, and input reorganization; updated the text model to use multimodal RoPE; introduced Qwen3VLVisionTransformerBlock with deepstack feature support and a sharded state dict; added a vision config initialization utility.
  • Model Providers & Bridge Mappings (src/megatron/bridge/models/gpt_provider.py, src/megatron/bridge/models/qwen_vl/qwen3_vl_provider.py, src/megatron/bridge/models/qwen_vl/qwen3_vl_bridge.py): added vision-related config fields to GPTModelProvider; expanded Qwen3VL providers with vision transformer specs and an HF vision model toggle; added a vision-language QKV mapping and deepstack merger parameter mappings.
  • Data & Conversion Utilities (src/megatron/bridge/data/vlm_datasets/hf_provider.py, src/megatron/bridge/models/conversion/param_mapping.py): added a pack_sequences_in_batch field to HFDatasetConversationProvider; corrected a documentation example in ConcatenatedQKVMapping.
  • Data Loader Configuration (src/megatron/bridge/recipes/qwen_vl/qwen3_vl.py): changed the MockVLMConversationProvider dataloader type from "single" to "cyclic" for distributed training compatibility.
  • Tests (tests/unit_tests/models/qwen_vl/modelling_qwen3_vl/test_model.py): updated the Qwen3VLModel test instantiation to pass a ProcessGroupCollection parameter.

Sequence Diagram(s)

sequenceDiagram
    participant Trainer
    participant Setup
    participant InitDist as Initialize<br/>Distributed
    participant VisionModule
    participant LanguageModule
    participant Communicator

    Trainer->>Setup: setup()
    Setup->>InitDist: torch_dist_init()
    alt use_dist_train enabled
        InitDist->>InitDist: Create HyperCommGrid for vision
        InitDist->>InitDist: Create HyperCommGrid for language
        InitDist->>Communicator: Build MultiModulePipelineCommunicator
        InitDist->>Setup: Return (pg_collection, grid_dict, pg_collection_dict)
        Setup->>Setup: Assign vision/language add flags based on rank membership
        Setup->>Setup: Wire communicator to model config
    else use_dist_train disabled
        InitDist->>Communicator: Build P2PCommunicator
        InitDist->>Setup: Return pg_collection
        Setup->>Setup: Set add_encoder/decoder to True
    end
    Setup->>Trainer: Return configured setup with communicator

    Trainer->>Trainer: Initialize model (Qwen3VL)
    Note over VisionModule,LanguageModule: Vision & Language modules<br/>only active based on add flags
    Trainer->>Trainer: Training loop

    Trainer->>VisionModule: forward() with vision inputs
    VisionModule-->>Trainer: vision embeddings + deepstack features
    Trainer->>LanguageModule: forward() with language inputs + vision embeddings
    LanguageModule-->>Trainer: logits
    Trainer->>Communicator: Handle inter-module communication if needed

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related PRs

Suggested reviewers

  • erhoo82
  • malay-nagda
  • ko3n1g
🚥 Pre-merge checks | ✅ 3 | ❌ 2

❌ Failed checks (2 warnings)

  • Docstring Coverage ⚠️ Warning: docstring coverage is 62.92%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
  • Test Results For Major Changes ⚠️ Warning: the PR description lacks required test results and testing information despite comprehensive test files and major infrastructure changes. Resolution: complete the PR description with test execution status, convergence verification results, performance metrics, and any known limitations before transitioning from DRAFT status.

✅ Passed checks (3 passed)

  • Merge Conflict Detection ✅ Passed: no merge conflicts detected when merging into main.
  • Title check ✅ Passed: the title 'support for training qwen3 vl with dist train' accurately describes the main objective of the changeset, which adds distributed training support for Qwen3-VL models across numerous files.
  • Description Check ✅ Passed: check skipped because CodeRabbit's high-level summary is enabled.




@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 17

Note

Due to the large number of review comments, Critical, Major severity comments were prioritized as inline comments.

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (7)
src/megatron/bridge/training/vlm_step.py (1)

145-192: 🛠️ Refactor suggestion | 🟠 Major

if True: creates unreachable dead code — clean up before merging.

The else branch (lines 159-192) is unreachable. This looks like a debug/WIP artifact. Either remove the dead branch or replace if True: with the intended condition.

Note: the dead else branch also has a latent bug at line 190 where if attn is not None: is outside the if tokens_or_input is not None: block, making target_len potentially undefined — but since it's dead code, this is moot for now.

Proposed cleanup
     # When using pipeline parallelism, ensure fixed shapes equal to cfg.model.seq_length
-    if True:
-        seq_len = cfg.model.seq_length
+    seq_len = cfg.model.seq_length
 
-        tokens_or_input = batch.get("tokens") if batch.get("tokens") is not None else batch.get("input_ids")
-        tokens_or_input = pad_or_truncate_2d_to_len(tokens_or_input, seq_len, seq_len, pad_value=0)
-        ...
-    else:
-        # No PP: pad sequence length to nearest multiple of 128 for efficiency (capped at model seq_length)
-        ...
+    tokens_or_input = batch.get("tokens") if batch.get("tokens") is not None else batch.get("input_ids")
+    tokens_or_input = pad_or_truncate_2d_to_len(tokens_or_input, seq_len, seq_len, pad_value=0)
+    ...

As per coding guidelines: "If code is commented out, include a comment describing its usage and why it is commented out; otherwise remove it as debug code before merging."

src/megatron/bridge/training/initialize.py (2)

134-214: ⚠️ Potential issue | 🟠 Major

torch_dist_init “triple return” contract is inconsistent across branches (lazy_init / skip_mpu_initialization).

torch_dist_init() is annotated to return (ProcessGroupCollection, grid_dict, pg_collection_dict), but it can return:

  • (None, None, None) when skip_mpu_initialization=True, and
  • (finish_mpu_init, None, None) when dist_config.lazy_init=True (callable in slot 0),
    and finish_mpu_init() is annotated as returning ProcessGroupCollection but returns the full 3-tuple.

This mismatch makes it very easy for downstream code to treat the first element as a ProcessGroupCollection and crash (especially after the new destructuring pattern in setup code).

Consider making the return type explicit and self-describing (e.g., a small dataclass/NamedTuple with pg_collection, grid_dict, pg_collection_dict, and optional finish_mpu_init), or at minimum fix the type hints + docstrings so call sites can branch safely on callable(pg_collection) and pg_collection is None.

As per coding guidelines, "Use type hints for function arguments and return types" and "Use 'T | None' for nullable types instead of 'Optional[T]'".


363-375: ⚠️ Potential issue | 🟡 Minor

Fix implicit Optional type hints and use built-in tuple generic (ruff RUF013).

Parameters world_size and rank_offset use implicit Optional syntax and should use union types. Return type should use the built-in tuple generic instead of Tuple from typing.

Proposed diff
 def _create_pg_collection(
     model_config: GPTModelProvider | T5ModelProvider,
     num_distributed_optimizer_instances: int,
     get_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]] = None,
     get_position_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]] = None,
-    world_size: int = None,
-    rank_offset: int = None,
-) -> Tuple[ProcessGroupCollection, HyperCommGrid]:
+    world_size: int | None = None,
+    rank_offset: int | None = None,
+) -> tuple[ProcessGroupCollection, HyperCommGrid]:
src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/transformer_config.py (1)

15-119: ⚠️ Potential issue | 🟠 Major

Remove unused torch import, fix assertion error handling, and avoid mutating hf_config in-place.

Issues to address:

  • import torch is unused (line 20 shows it's imported but never referenced; only torch.nn.functional is used).
  • assert config.vision_model_type is None, ValueError(...) is incorrect syntax—it raises AssertionError with a ValueError object as the message, and assertions can be disabled at runtime. Use raise NotImplementedError(...) instead to fail fast with a clear error.
  • Mutating hf_config in-place (depth, hidden_size, num_heads, etc.) couples the function to the input object and risks unintended side effects if vision_transformer_config is reused elsewhere. Store these values in local variables or construct a copy of the config to avoid mutation.
Proposed diff
-import torch
 import torch.nn.functional as F
@@
 def get_vision_model_config(config: Qwen3VLTransformerConfig, hf_config):
+    """Populate a Qwen3VLTransformerConfig instance with vision-model settings.
+
+    Note: This function mutates and returns `config`.
+    """
     config.num_moe_experts = None
     config.expert_model_parallel_size = 1
     config.moe_ffn_hidden_size = None
@@
     if config.vision_model_type == "vit_2b":
-        hf_config.depth = 45
-        hf_config.hidden_size = 1536
-        hf_config.num_heads = 16
-        hf_config.intermediate_size = 8960
-        hf_config.patch_size = 16
-        hf_config.spatial_merge_size = 2
-        if hasattr(hf_config, "head_dim"):
-            hf_config.head_dim = 96
+        hf_depth = 45
+        hf_hidden_size = 1536
+        hf_num_heads = 16
+        hf_intermediate_size = 8960
+        hf_patch_size = 16
+        hf_spatial_merge_size = 2
+        hf_head_dim = 96
     else:
-        assert config.vision_model_type is None, ValueError(f"support only vit_2b, but got {config.vision_model_type}")
+        raise NotImplementedError(f"Only vision_model_type='vit_2b' is supported, got: {config.vision_model_type}")
@@
-    config.num_layers = hf_config.depth
-    config.ffn_hidden_size = hf_config.intermediate_size
-    config.num_attention_heads = hf_config.num_heads  # num_heads
+    config.num_layers = hf_depth
+    config.ffn_hidden_size = hf_intermediate_size
+    config.num_attention_heads = hf_num_heads  # num_heads
@@
-    config.hidden_size = hf_config.hidden_size  # hidden_size
+    config.hidden_size = hf_hidden_size  # hidden_size
@@
-    config.patch_size = hf_config.patch_size
+    config.patch_size = hf_patch_size
@@
-    config.spatial_merge_size = hf_config.spatial_merge_size
+    config.spatial_merge_size = hf_spatial_merge_size
src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/rope.py (1)

30-126: ⚠️ Potential issue | 🟠 Major

Register inv_freq as a buffer and fix type hints to use | instead of Optional.

Two issues need fixing:

  1. inv_freq must be a buffer: Currently stored as a plain tensor attribute (line 65-67), it won't track device moves when module.to(device) is called and won't serialize/deserialize properly with state_dict. The codebase already uses this pattern (see utils.py line 84). Additionally, specifying device=torch.cuda.current_device() at initialization is problematic—the tensor should be created on the default device and follow module movements.

  2. Type hints: Per coding guidelines, use T | None instead of Optional[T]. Line 52 and 54 should use the modern union syntax.

Proposed diff
     def __init__(
         self,
         kv_channels: int,
         rotary_percent: float,
         rotary_interleaved: bool = False,
-        seq_len_interpolation_factor: Optional[float] = None,
+        seq_len_interpolation_factor: float | None = None,
         rotary_base: int = 10000,
-        cp_group: torch.distributed.ProcessGroup = None,
+        cp_group: torch.distributed.ProcessGroup | None = None,
     ) -> None:
         super().__init__()
         
         dim = kv_channels
         if rotary_percent < 1.0:
             dim = int(dim * rotary_percent)
         self.rotary_interleaved = rotary_interleaved
         assert not self.rotary_interleaved, "only support qwen3vl"
         
         self.seq_len_interpolation_factor = seq_len_interpolation_factor
-        self.inv_freq = 1.0 / (
-            rotary_base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=torch.cuda.current_device()) / dim)
-        )
+        inv_freq = 1.0 / (rotary_base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.is_thd_format = False
         self.cp_group = cp_group
src/megatron/bridge/training/setup.py (1)

52-167: ⚠️ Potential issue | 🔴 Critical

setup(): handle lazy-init and skip-mpu-init returns before accessing pg_collection.pp

When lazy_init=True, initialize_megatron returns a callable as the first element of the tuple; when skip_mpu_initialization=True, it returns None. Directly accessing pg_collection.pp on line 167 will crash in both cases with AttributeError or TypeError.

Add type checking after unpacking:

  • If the first element is callable, invoke it to get the real (pg_collection, grid_dict, pg_collection_dict)
  • If pg_collection is still None, raise a clear error
Proposed fix
     pg_collection, grid_dict, pg_collection_dict = initialize_megatron(
         cfg=cfg,
         get_embedding_ranks=get_embedding_ranks,
         get_position_embedding_ranks=get_position_embedding_ranks,
         restart_store=restart_store,
     )
+    if callable(pg_collection):
+        pg_collection, grid_dict, pg_collection_dict = pg_collection()
+    
+    if pg_collection is None:
+        raise RuntimeError("initialize_megatron did not return a ProcessGroupCollection (pg_collection=None)")
+
     if hasattr(cfg.model, "use_dist_train") and cfg.model.use_dist_train:
tests/unit_tests/models/qwen_vl/modelling_qwen3_vl/test_model.py (1)

314-323: ⚠️ Potential issue | 🟠 Major

Pass pg_collection to all Qwen3VLModel instances.
Qwen3VLModel now dereferences pg_collection.*; the constructors for model_no_decoder and model_no_pre will raise when pg_collection is omitted. Reuse the collection already created in the test.

🛠️ Proposed fix
         model_no_decoder = Qwen3VLModel(
             vision_transformer_config=vision_transformer_config,
             language_transformer_config=language_transformer_config,
             language_transformer_layer_spec=language_model_layer_spec,
             parallel_output=True,
             pre_process=True,
             post_process=True,
             add_encoder=True,
             add_decoder=False,
+            pg_collection=pg_collection,
         )
@@
         model_no_pre = Qwen3VLModel(
             vision_transformer_config=vision_transformer_config,
             language_transformer_config=language_transformer_config,
             language_transformer_layer_spec=language_model_layer_spec,
             parallel_output=True,
             pre_process=False,
             post_process=True,
             add_encoder=True,
             add_decoder=True,
+            pg_collection=pg_collection,
         )

Also applies to: 366-375

🤖 Fix all issues with AI agents
In `@examples/recipes/decentralized_pg/pretrain_qwen3_vl_simple.py`:
- Line 37: Replace the import of forward_step from the generic vlm_step with the
Qwen3-VL specific implementation: change the import to use
megatron.bridge.training.qwen3vl_step so the script uses forward_step
implemented in qwen3vl_step; this ensures the Qwen3-VL assertions, data-format
handling (bshd vs thd), position_ids logic, packed sequence handling, and
multimodal input injection (pixel_values and image_grid_thw) are applied instead
of the generic vlm_step behavior.

In `@examples/recipes/qwen_vl/finetune_qwen_vl.py`:
- Around line 106-108: The file unconditionally imports forward_step from
megatron.bridge.training.qwen3vl_step which asserts the model is Qwen3-VL and
breaks Qwen2.5-VL runs; move that import into main() after you determine
recipe/model_family and conditionally import forward_step from the correct
module (if model_family == "Qwen3-VL" import from
megatron.bridge.training.qwen3vl_step else import from
megatron.bridge.training.vlm_step) so the assertion isn’t triggered at import
time, and remove the unused from functools import partial import which is
misplaced after first-party imports.
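A minimal sketch of that selection logic. The module paths are taken verbatim from the review text; the actual deferred import would then happen inside main() via importlib or a local import statement:

```python
def forward_step_module(model_family: str) -> str:
    """Map the detected model family to the forward-step module to import
    lazily inside main(). Paths are as quoted in the review."""
    if model_family == "Qwen3-VL":
        return "megatron.bridge.training.qwen3vl_step"
    return "megatron.bridge.training.vlm_step"


print(forward_step_module("Qwen2.5-VL"))  # megatron.bridge.training.vlm_step
```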

In `@src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py`:
- Around line 109-114: Guard against pg_collection being None before
dereferencing: check if pg_collection is not None before accessing
pg_collection.cp/tp/pp and assigning self.pg_collection, self.cp_group,
self.tp_group, self.pp_group; if pg_collection is None set those group
attributes to None (or appropriate defaults) and adjust the assert (or replace
it with an explicit check) so you only call hasattr(self.pg_collection, "embd")
when self.pg_collection is not None. Ensure you update any code paths that
assume these groups exist to handle the None/default case.
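Sketched with stand-in types (SimpleNamespace plays the role of ProcessGroupCollection; the class name is hypothetical), the guard could look like:

```python
from types import SimpleNamespace


class ModelGroupsSketch:
    """Stand-in illustrating the None-guard the review describes."""

    def __init__(self, pg_collection=None):
        self.pg_collection = pg_collection
        if pg_collection is not None:
            # Explicit check instead of a bare assert, only when non-None
            if not hasattr(pg_collection, "embd"):
                raise ValueError("pg_collection must provide an 'embd' group")
            self.cp_group = pg_collection.cp
            self.tp_group = pg_collection.tp
            self.pp_group = pg_collection.pp
        else:
            # No collection supplied: fall back to None groups
            self.cp_group = self.tp_group = self.pp_group = None


pgs = SimpleNamespace(cp="cp", tp="tp", pp="pp", embd="embd")
print(ModelGroupsSketch(pgs).tp_group)  # tp
print(ModelGroupsSketch().tp_group)     # None
```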

In `@src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/text_model.py`:
- Around line 84-92: The constructor dereferences pg_collection.cp when creating
self.rotary_pos_emb (Qwen3VLMultimodalRotaryEmbedding) but pg_collection may be
None; guard this by checking/initializing pg_collection before use or asserting
non-None: e.g., if pg_collection is None create a default with the expected .cp
attribute (or raise a clear error), then pass pg_collection.cp into
Qwen3VLMultimodalRotaryEmbedding; ensure the change touches the place where
rotary_pos_emb is assigned and uses pg_collection only after the guard.

In `@src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/transformer_block.py`:
- Around line 86-95: The deepstack_merger_list constructs
Qwen3VLVisionPatchMerger instances without passing the block's tensor-parallel
group, causing linear layers (linear_fc1/linear_fc2) inside those mergers to use
the wrong TP group; update the Qwen3VLVisionPatchMerger construction inside
deepstack_merger_list to include tp_group=self.tp_group (same pattern used in
vision_model where mergers pass tp_group=self.tp_group) so each merger uses the
block's TP group for its linear layers, ensuring consistent tensor-parallel
behavior across config.deepstack_visual_indexes.
- Around line 429-433: The call to sharded_state_dict_default inside
Qwen3VLVisionTransformerBlock's sharded_state_dict loop will raise NameError and
the identity check uses non-PEP8 syntax; fix by importing or providing a
fallback implementation of sharded_state_dict_default (used by
sharded_state_dict when iterating named_children) and replace "not module is
self.layers" / "not module is self.deepstack_merger_list" with "module is not
self.layers" / "module is not self.deepstack_merger_list" respectively; ensure
the imported or fallback sharded_state_dict_default signature matches (module,
prefix, sharded_offsets, metadata) so sharded_state_dict can call it safely.
- Around line 363-435: The method TransformerBlock.sharded_state_dict calls an
undefined helper sharded_state_dict_default and also bypasses parent logic; fix
by first invoking super().sharded_state_dict(prefix, sharded_offsets, metadata)
and merging its result, replace the undefined call by importing/using the
correct helper function (or the module-level utility that actually exists in the
codebase) instead of sharded_state_dict_default when iterating named_children
(reference the symbol sharded_state_dict_default and ensure it matches the real
exported name), and keep using self.pp_group.rank() but verify pp_group
implements rank() where assigned; in short: call super(), swap the undefined
helper for the correct imported helper, and merge results rather than fully
overriding parent behavior.
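The call-super-and-merge pattern can be sketched without Megatron at all; here BaseBlock and the dict values are placeholders, and storing the child module directly stands in for calling the real sharded-state-dict helper:

```python
class BaseBlock:
    """Placeholder parent whose sharded_state_dict we merge from."""

    def sharded_state_dict(self, prefix=""):
        return {f"{prefix}base": 0}


class VisionBlockSketch(BaseBlock):
    def __init__(self):
        self.layers = ["layer0"]
        self.deepstack_merger_list = ["merger0"]
        self.norm = "norm"

    def named_children(self):
        yield "layers", self.layers
        yield "deepstack_merger_list", self.deepstack_merger_list
        yield "norm", self.norm

    def sharded_state_dict(self, prefix=""):
        # Start from the parent result and merge, rather than overriding
        sharded = super().sharded_state_dict(prefix)
        for name, module in self.named_children():
            # PEP8 identity checks: `is not`, not `not ... is`
            if module is not self.layers and module is not self.deepstack_merger_list:
                sharded[f"{prefix}{name}"] = module  # stand-in for the real helper
        return sharded


print(sorted(VisionBlockSketch().sharded_state_dict("vb.").keys()))
```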

In `@src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/utils.py`:
- Around line 366-372: In split_data_cp_rank, move the None check for the input
tensor before any attribute access so val is returned early if None (i.e., check
`if val is None: return val` before using `val.shape`), and update the cp_rank
annotation from `cp_rank: int = None` to use an explicit Optional type
(`cp_rank: Optional[int] = None`) to comply with PEP 484; keep the rest of the
logic (cp_size assert and cp_rank defaulting via
mpu.get_context_parallel_rank()) unchanged.
- Around line 349-350: Replace the unreachable assert with a real exception: in
the else branch that currently does "assert False, f'should not have
{token_id=}'", raise a ValueError instead (e.g. "raise ValueError(f'should not
have token_id={token_id}')") so the error is not stripped by python -O and
clearly reports the invalid token_id; update the branch where token_id is
handled accordingly.

In `@src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/vision_model.py`:
- Around line 42-50: The constructor currently defaults pg_collection to None
but accesses self.pg_collection.tp in __init__, which can raise AttributeError;
update the __init__ of the class to validate pg_collection (the pg_collection
parameter and self.pg_collection) before accessing .tp — e.g., if pg_collection
is None raise a clear ValueError stating "pg_collection is required" (or
alternatively assign self.tp_group = None when pg_collection is None), then set
self.pg_collection = pg_collection and self.tp_group = self.pg_collection.tp
only after the null check so the attribute access is safe.

In `@src/megatron/bridge/models/qwen_vl/qwen3_vl_provider.py`:
- Around line 146-157: Qwen3VLModelProvider (and Qwen3ModelProvider) reference
attributes add_encoder and add_decoder when constructing Qwen3VLModel but those
attributes are never defined; fix by adding dataclass fields add_encoder: bool =
False and add_decoder: bool = False to the provider class (or its parent) so
they're initialized, or alternatively pass explicit boolean literals into the
Qwen3VLModel constructor where add_encoder/add_decoder are currently used
(mimicking nemotron_vl_provider.py); update both the Qwen3VLModelProvider and
the MoE variant locations that reference add_encoder/add_decoder to use the new
fields or hardcoded values.
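The dataclass-field option might look like this; the class name and the False defaults are assumptions following the review's wording:

```python
from dataclasses import dataclass


@dataclass
class Qwen3VLProviderSketch:
    # Fields the review says are read but never defined on the provider
    add_encoder: bool = False
    add_decoder: bool = False


p = Qwen3VLProviderSketch(add_encoder=True)
print(p.add_encoder, p.add_decoder)  # True False
```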

In `@src/megatron/bridge/training/qwen3vl_step.py`:
- Around line 136-140: The code currently calls
batch.get("visual_inputs").normalized_for_model() without checking for None;
update the block around where "visual_inputs" is handled so you first assign
visual = batch.get("visual_inputs") and only call visual.normalized_for_model()
if visual is not None, otherwise set multi_modal_inputs to None (or skip
creating it) so downstream uses of multi_modal_inputs are guarded; modify the
logic in qwen3vl_step.py where multi_modal_inputs is created to reference the
new visual variable and handle the None case safely.
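A minimal sketch of the guard. VisualInputsStub only mimics the normalized_for_model() call named in the review; the real object and its return shape are not shown here:

```python
class VisualInputsStub:
    """Stand-in for the real visual-inputs object."""

    def normalized_for_model(self):
        return {"pixel_values": "..."}


def build_multi_modal_inputs(batch: dict):
    visual = batch.get("visual_inputs")
    if visual is None:
        return None  # downstream code must tolerate the absent case
    return visual.normalized_for_model()


print(build_multi_modal_inputs({}))  # None
print(build_multi_modal_inputs({"visual_inputs": VisualInputsStub()}))
```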

In `@src/megatron/bridge/training/train.py`:
- Around line 739-740: Guard the last-stage loss reduction by checking whether
cfg.model.p2p_communicator is not None before accessing .is_pp_last_stage; if it
is None (PP size 1), call is_pp_last_stage(pg_collection.pp) instead. Update the
conditional that currently reads cfg.model.p2p_communicator.is_pp_last_stage to
first test cfg.model.p2p_communicator and then fall back to
is_pp_last_stage(pg_collection.pp) so the loss averaging path only runs when the
correct determination of last stage succeeds.
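Sketched with a simplified local stand-in for the Megatron helper (the real is_pp_last_stage is imported, not defined like this):

```python
def is_pp_last_stage(pp_group) -> bool:
    # Simplified stand-in for the Megatron helper of the same name
    return pp_group.rank() == pp_group.size() - 1


def on_last_pp_stage(p2p_communicator, pp_group) -> bool:
    # Prefer the communicator when it exists; fall back when PP size is 1
    if p2p_communicator is not None:
        return p2p_communicator.is_pp_last_stage
    return is_pp_last_stage(pp_group)
```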

In `@src/megatron/bridge/training/utils/packed_seq_utils.py`:
- Around line 73-74: Move the line "from megatron.core import mpu" up into the
module import block with the other imports (remove the "// Copied from ..."
comment entirely), and ensure the import is colocated with the top-level imports
used by this file; then refactor to avoid duplication by either importing the
existing implementations of preprocess_packed_seqs and postprocess_packed_seqs
from the shared implementation in the qwen_vl utils module or extract a new
shared utility module and have both locations import from it (preserve the
use_fp8_padding behavior when consolidating).
- Around line 130-157: The loop unconditionally uses attention_mask[i] causing a
crash when attention_mask is None; update the loop in the pre_process branch to
guard use of attention_mask: compute d = input_ids[i, attention_mask[i]] if
attention_mask is not None, otherwise set d = input_ids[i, :seqlen] for the
cp_size <= 1 branch (use seqlen from seqlens_in_batch_cpu[i]) and d =
input_ids[i, :seqlen_padded_i] for the cp_size > 1 branch (use seqlen_padded_i
from seqlens_in_batch_padded_cpu[i]) before the subsequent slicing that writes
into input_ids_rmpad so indexing never tries to subscript None; adjust
references around input_ids_rmpad, seqlens_in_batch_cpu,
seqlens_in_batch_padded_cpu, cu_seqlens_padded_cpu, cp_size and cp_rank.
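The per-row guard can be distilled to the following sketch; plain lists stand in for tensors, and the cp_size branching collapses to a single seqlen argument:

```python
def valid_tokens(row, mask_row, seqlen):
    # Fall back to length-based slicing when attention_mask is None
    if mask_row is not None:
        return [tok for tok, keep in zip(row, mask_row) if keep]
    return row[:seqlen]


print(valid_tokens([5, 6, 7, 0], [1, 1, 1, 0], 3))  # [5, 6, 7]
print(valid_tokens([5, 6, 7, 0], None, 3))          # [5, 6, 7]
```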
🟡 Minor comments (12)
src/megatron/bridge/training/vlm_step.py-124-125 (1)

124-125: ⚠️ Potential issue | 🟡 Minor

Hardcoded is_first = True / is_last = True bypasses pipeline-stage detection.

This forces all ranks to load labels/loss_mask and visual inputs, regardless of their actual PP stage. If this is intentional for the multi-module VLM distributed training path, please add a comment explaining why pipeline-stage gating was removed (and whether the original is_pp_first_stage/is_pp_last_stage imports on line 21 are now dead).

examples/recipes/decentralized_pg/pretrain_qwen3_vl_simple.py-41-43 (1)

41-43: ⚠️ Potential issue | 🟡 Minor

Misleading comment: "Qwen3 4B" vs actual 30B config.

Line 42 says "Get the standard Qwen3 4B pretrain config" but line 43 uses qwen3_vl_30b_a3b_pretrain_config.

Fix
-    # Get the standard Qwen3 4B pretrain config with overrides
+    # Get the standard Qwen3-VL 30B MoE pretrain config with overrides
src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/rope.py-265-307 (1)

265-307: ⚠️ Potential issue | 🟡 Minor

Silence unused cu_seqlens warning and make RoPE-fusion failure explicit.

  • apply_rotary_pos_emb_thd_absolute(..., cu_seqlens, ...) doesn’t use cu_seqlens (ruff ARG001). Rename to _cu_seqlens (or add a targeted noqa) to keep the interface but avoid lint noise.
  • assert not config.apply_rope_fusion will be stripped under python -O; prefer a real exception (NotImplementedError / ValueError) if this is a hard constraint.

As per coding guidelines, "When a feature is not supported (such as audio embeddings), raise an explicit error (e.g., NotImplementedError) instead of silently ignoring the input to fail fast with a clear message." Based on learnings: "when a feature is not supported ... raise an explicit error (e.g., NotImplementedError) instead of silently ignoring".
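A minimal illustration of why a real exception is preferred here: `assert` statements are compiled away under `python -O`, while a raised `NotImplementedError` survives in every interpreter mode (the `Config` class below is a stand-in, not the actual `TransformerConfig`):

```python
class Config:
    apply_rope_fusion = True

def apply_rope_unfused(config):
    # An assert would disappear under `python -O`, silently allowing the unsupported path.
    # Raising keeps the constraint enforced regardless of optimization flags.
    if config.apply_rope_fusion:
        raise NotImplementedError("apply_rope_fusion is not supported for THD absolute RoPE")
    return "ok"

try:
    apply_rope_unfused(Config())
    raised = False
except NotImplementedError:
    raised = True
assert raised
```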

src/megatron/bridge/training/qwen3vl_step.py-41-47 (1)

41-47: ⚠️ Potential issue | 🟡 Minor

Silence unused-argument warnings and remove the unused variable.
Ruff flags ARG001/F841 in this function.

🛠️ Proposed fix
     batch = next(data_iterator)
+    _ = use_mtp
+    _ = is_first_pp_stage
+    _ = is_last_pp_stage
@@
-        max_seqlen_in_batch = seqlens_in_batch.max().item()

Also applies to: 224-224

src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py-16-18 (1)

16-18: ⚠️ Potential issue | 🟡 Minor

Align type hints with repo conventions and drop the unused import.
Use built-in generics/union types and remove the unused mpu import to satisfy lint and style rules.

🛠️ Proposed fix
-from typing import List, Dict
+ 
@@
-from megatron.core import InferenceParams, mpu, tensor_parallel
+from megatron.core import InferenceParams, tensor_parallel
@@
-    def set_input_tensor(self, input_tensor: List[Dict[str, torch.Tensor]]):
+    def set_input_tensor(self, input_tensor: list[dict[str, torch.Tensor]]):
@@
-        cp_img_num: list[int] = None,
-        images_padded: list[bool] = None,
+        cp_img_num: list[int] | None = None,
+        images_padded: list[bool] | None = None,
As per coding guidelines, Use built-in generics (list, dict, tuple) instead of typing equivalents, and use 'T | None' for nullable types instead of 'Optional[T]'.

Also applies to: 195-200, 300-301

src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/attention.py-40-41 (1)

40-41: ⚠️ Potential issue | 🟡 Minor

Fix lint issues: unused parameter and f-string without placeholders.
Ruff flags the unused rotary_pos_cos_sin argument and the f prefix on a static string.

🛠️ Proposed fix
-        rotary_pos_cos_sin: Optional[Tensor] = None,
+        _rotary_pos_cos_sin: Optional[Tensor] = None,
@@
-            raise ValueError(f"CUDA graphs must use flash decode with static batching!")
+            raise ValueError("CUDA graphs must use flash decode with static batching!")

Also applies to: 126-127

src/megatron/bridge/training/train.py-39-40 (1)

39-40: ⚠️ Potential issue | 🟡 Minor

Remove unused MultiModulePipelineCommunicator import (F401).
Lint currently flags this as unused.

🛠️ Proposed fix
-from megatron.core.pipeline_parallel.multimodule_communicator import MultiModulePipelineCommunicator
src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py-152-155 (1)

152-155: ⚠️ Potential issue | 🟡 Minor

Use print_rank_0 for model logging.
This avoids duplicate logs across ranks during distributed training.

🛠️ Proposed fix
-from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig as Qwen3VLConfigHF
+from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig as Qwen3VLConfigHF
+from megatron.bridge.utils.common_utils import print_rank_0
@@
-                print(f"rank {torch.distributed.get_rank()} use hf vision model")
+                print_rank_0(f"rank {torch.distributed.get_rank()} use hf vision model")
As per coding guidelines, Use 'print_rank_0' for logging in model bridge to avoid duplicate output across ranks.
src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/utils.py-242-242 (1)

242-242: ⚠️ Potential issue | 🟡 Minor

Typo: "flaaten" → "flatten".

-    assert input_ids.dim() == 1, "input_ids should be flaaten"
+    assert input_ids.dim() == 1, "input_ids should be flattened"
src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/vision_model.py-1-24 (1)

1-24: ⚠️ Potential issue | 🟡 Minor

Missing NVIDIA copyright header and unused imports.

This file is missing the required NVIDIA copyright header. Additionally, static analysis flags several unused imports: dataclass, Union, MultimodalProjector, MegatronModule, build_module, and get_tensor_model_parallel_group_if_none.

Proposed fix
+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
-from dataclasses import dataclass
-from typing import Optional, Union
+from __future__ import annotations
 
 import torch
 from megatron.core import InferenceParams
 from megatron.core.models.common.vision_module.vision_module import VisionModule
-from megatron.core.models.vision.multimodal_projector import MultimodalProjector
 from megatron.core.packed_seq_params import PackedSeqParams
 from megatron.core.process_groups_config import ProcessGroupCollection
 from megatron.core.transformer.enums import ModelType
-from megatron.core.transformer.module import MegatronModule
-from megatron.core.transformer.spec_utils import ModuleSpec, build_module
-from megatron.core.utils import get_tensor_model_parallel_group_if_none
+from megatron.core.transformer.spec_utils import ModuleSpec
 from torch import nn
 from torch.nn import functional as F

As per coding guidelines: "Add NVIDIA copyright header to all Python files" and "Use T | None for nullable types instead of Optional[T]".

src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/vision_model.py-197-245 (1)

197-245: ⚠️ Potential issue | 🟡 Minor

Docstring args/return mismatch with actual signature and return value.

  • Docstring refers to parameter x (Line 208) but the actual parameter is hidden_states.
  • Docstring mentions packed_seq_params (Line 210) which is not in the function signature (it's computed internally on Line 235).
  • The return type annotation says torch.Tensor (Line 203) but the function returns a tuple (hidden_states, deepstack_feature_lists) on Line 245.
Proposed fix
     def forward(
         self,
         hidden_states: Optional[torch.Tensor],
         grid_thw: torch.Tensor,
         inference_params: Optional[InferenceParams] = None,
         extra_block_kwargs: dict = None,
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, list]:
         """Forward function of the Qwen3 Vision Model. This function passes the input tensors
         through the embedding layer and then the transformer.
 
         Args:
-            x (torch.Tensor): input image/video data of shape [n_tokens, n_dims]
+            hidden_states (torch.Tensor): input image/video data of shape [n_tokens, n_dims]
             grid_thw (torch.Tensor): the size tensor indicates grid size of each image/frame
-            packed_seq_params (PackedSeqParams): parameters to build attention mask in the backend
+            inference_params (InferenceParams, optional): inference parameters
+            extra_block_kwargs (dict, optional): additional keyword arguments for the decoder block
 
         Returns:
-            x (torch.Tensor): output after final transformer block of shape [b, s, h].
+            tuple: (hidden_states, deepstack_feature_lists)
         """
src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/utils.py-666-679 (1)

666-679: ⚠️ Potential issue | 🟡 Minor

cu_seqlens_padded is passed undivided to PackedSeqParams, but this is currently a latent issue since the returned packed_seq_params is never used.

While the code structure does create a mismatch when cp_size > 1 (buffer is sized as sum(seqlens_in_batch_padded_cpu) // cp_size but cu_seqlens_padded reflects full, undivided cumulative lengths), this is not a critical runtime bug in the current codebase. All calls to preprocess_packed_seqs() discard the returned packed_seq_params with _ (lines 437, 456, 467, 523 in model.py), and the parameter passed to the language model is always None (line 536).

The data itself is correctly divided during packing (indices are adjusted: start_idx = cu_seqlens_padded_cpu[i] // cp_size on line 650). However, if packed_seq_params were ever used downstream with CP enabled, the mismatch would cause incorrect indexing. Consider either: (1) adjust cu_seqlens_padded and max_seqlen_in_batch for the per-rank view when cp_size > 1, or (2) clarify whether this function is intended to support CP splitting and document the limitation.
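Option (1) above amounts to scaling the cumulative lengths to the per-rank view. A sketch with plain ints (the divisibility of each padded length by `cp_size` is assumed, as the padding logic in the PR guarantees):

```python
def per_rank_cu_seqlens(cu_seqlens_padded, cp_size):
    """Scale padded cumulative sequence lengths to one CP rank's share."""
    # Each padded sequence length is divisible by cp_size by construction,
    # so per-rank cumulative offsets are the full offsets divided by cp_size.
    return [c // cp_size for c in cu_seqlens_padded]

# Two sequences padded to lengths 8 and 4, split across cp_size=2 ranks.
assert per_rank_cu_seqlens([0, 8, 12], cp_size=2) == [0, 4, 6]
```

With offsets adjusted this way, the `PackedSeqParams` would match the `sum(...) // cp_size` buffer sizing if the returned object were ever consumed downstream.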

🧹 Nitpick comments (17)
scripts/performance/argument_parser.py (1)

148-148: Consider whether qwen3vl belongs as a --domain choice or should be handled via --model_family_name.

The existing domain values (llm, vlm) are generic categories, whereas qwen3vl is model-specific. This sets a precedent where each new VLM variant could require its own domain entry. If the Qwen3-VL training path diverges significantly enough from the general vlm path to justify a separate domain, this is fine — but if the differences are minor, routing through --model_family_name (or a sub-option) would keep the domain list stable.

examples/recipes/qwen_vl/finetune_qwen_vl.py (1)

93-93: Use built-in tuple instead of typing.Tuple for Python 3.10+.

Line 118 uses Tuple[argparse.Namespace, list[str]], mixing typing.Tuple with built-in list. Per coding guidelines, prefer built-in generics.

Proposed fix
-from typing import Tuple

And on line 118:

-def parse_cli_args() -> Tuple[argparse.Namespace, list[str]]:
+def parse_cli_args() -> tuple[argparse.Namespace, list[str]]:

As per coding guidelines: "Use built-in generics (list, dict, tuple) instead of typing equivalents".

src/megatron/bridge/training/utils/packed_seq_utils.py (2)

77-78: Missing type hint for use_fp8_padding parameter.

Fix
-def preprocess_packed_seqs(
-    input_ids: torch.Tensor, attention_mask: torch.Tensor, pre_process: bool = True, use_fp8_padding=False
-) -> tuple[torch.Tensor, PackedSeqParams]:
+def preprocess_packed_seqs(
+    input_ids: torch.Tensor, attention_mask: torch.Tensor | None, pre_process: bool = True, use_fp8_padding: bool = False
+) -> tuple[torch.Tensor, PackedSeqParams]:

As per coding guidelines: "Use type hints for function arguments and return types" and "Use 'T | None' for nullable types instead of 'Optional[T]'".


196-197: Minor: Use unpacking instead of list concatenation.

Per Ruff RUF005, prefer [batch_size, seq_len, *list(output.shape[2:])].

Fix
-    shape = [batch_size, seq_len] + list(output.shape[2:])  # 1,packed, dim -> batch_size, seq_len, dim
+    shape = [batch_size, seq_len, *output.shape[2:]]  # 1,packed, dim -> batch_size, seq_len, dim
src/megatron/bridge/models/qwen_vl/qwen3_vl_provider.py (2)

39-39: Remove unused import get_vision_model_config.

Flake8 confirms get_vision_model_config is imported but never used in this file.

Fix
-from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.transformer_config import get_vision_model_config

139-157: Vision spec construction is duplicated between dense and MoE providers.

Lines 139-144 (dense provider) and lines 307-312 (MoE provider) construct identical vision_transformer_layer_spec and vision_patch_merger_spec objects. Consider extracting a shared helper method (e.g., on a common base class or as a module-level function) to keep these in sync.

src/megatron/bridge/training/initialize.py (1)

589-648: Dist-train PG split: cache rank-membership checks + validate world-size partitioning.

This block calls is_rank_in_pg(...) multiple times for the same collections; please cache results to avoid repeated get_process_group_ranks() calls during init.

Also recommend adding a sanity check that model_config.vision_world_size + model_config.language_world_size == torch.distributed.get_world_size() (or get_world_size_safe()), otherwise you can end up with “neither in vision nor language” ranks that pass the current logic until a later failure.

As per coding guidelines, "Follow the existing code style and conventions as documented in CODING_GUIDELINES.md".
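The suggested sanity check can be as small as this (a sketch; the real version would read `model_config` fields and `torch.distributed.get_world_size()` rather than take plain ints):

```python
def check_partition(vision_world_size, language_world_size, world_size):
    """Fail fast when the vision/language split does not cover every rank exactly once."""
    if vision_world_size + language_world_size != world_size:
        raise ValueError(
            f"vision_world_size ({vision_world_size}) + language_world_size "
            f"({language_world_size}) must equal world_size ({world_size})"
        )

check_partition(8, 24, 32)  # valid split passes silently
try:
    check_partition(8, 16, 32)
    ok = False
except ValueError:
    ok = True
assert ok
```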

src/megatron/bridge/training/config.py (1)

1318-1413: Add invariant checks for use_dist_train so dp-size computation can’t silently go wrong.

When use_dist_train=True, dp-size is derived from model.language_world_size. That’s fine if model.tensor_model_parallel_size, pipeline_model_parallel_size, and context_parallel_size are also the language-side values at this point. If those fields still reflect a “combined” or vision-side topology, dp-size will be incorrect.

Suggestion: add a small validation in set_data_parallel_size() (or model.finalize()) that language_world_size % (tp*pp*cp) == 0 with the exact values used here, and raise a clear error if not.

As per coding guidelines, "Be explicit about required vs optional fields in configuration objects; do not add arbitrary defaults".
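The divisibility invariant could look like the following (a sketch of the check, not the actual `set_data_parallel_size()` code; parameter names follow the config fields discussed above):

```python
def language_data_parallel_size(language_world_size, tp, pp, cp):
    """Derive DP size from the language-side world size, validating divisibility."""
    model_parallel = tp * pp * cp
    if language_world_size % model_parallel != 0:
        raise ValueError(
            f"language_world_size ({language_world_size}) must be divisible by "
            f"tp*pp*cp ({model_parallel})"
        )
    return language_world_size // model_parallel

assert language_data_parallel_size(24, tp=2, pp=3, cp=1) == 4
```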

src/megatron/bridge/models/gpt_provider.py (1)

200-210: Use T | None for the new nullable config fields.
Align new fields with the repo’s nullable type-hint convention.

🛠️ Proposed fix
-    vision_model_type: Optional[str] = None
+    vision_model_type: str | None = None
@@
-    dist_train_vision_chunk_size: Optional[int] = 1
-    vision_world_size: Optional[int] = None
-    language_world_size: Optional[int] = None
-    vision_tensor_model_parallel_size: Optional[int] = None
-    vision_pipeline_model_parallel_size: Optional[int] = None
-    vision_context_parallel_size: Optional[int] = None
-    vision_expert_tensor_parallel_size: Optional[int] = None
-    vision_expert_model_parallel_size: Optional[int] = None
+    dist_train_vision_chunk_size: int | None = 1
+    vision_world_size: int | None = None
+    language_world_size: int | None = None
+    vision_tensor_model_parallel_size: int | None = None
+    vision_pipeline_model_parallel_size: int | None = None
+    vision_context_parallel_size: int | None = None
+    vision_expert_tensor_parallel_size: int | None = None
+    vision_expert_model_parallel_size: int | None = None
As per coding guidelines, Use 'T | None' for nullable types instead of 'Optional[T]'.
src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/utils.py (4)

16-17: Use modern union type syntax per coding guidelines.

The coding guidelines require T | None instead of Optional[T] and X | Y instead of Union[X, Y] (Python 3.10+). These typing imports are used throughout the file (lines 92–94, 103, etc.).

Suggested fix
-from typing import Optional, Union
+from __future__ import annotations

Then replace all Optional[X] with X | None and Union[X, Y] with X | Y throughout the file.


569-596: AllGatherVisionEmbeddings.backward — integer start_idx when cp_rank == 0 may cause slicing issues.

On Line 593, when cp_rank == 0, torch.cat(seqlens_on_cp_ranks[:0]) produces an empty tensor, and .sum() returns a 0-dim tensor (scalar 0), not a Python int. Then on Line 595, grad_output[start_idx:end_idx] uses tensor indices. While this works in practice, mixing tensor and int index types is fragile. Separately, ctx.save_for_backward stores every tensor in seqlens_on_cp_ranks; verify they are detached and on the correct device.

Safer start_idx computation
-        start_idx = torch.cat(seqlens_on_cp_ranks[:cp_rank]).sum() if cp_rank != 0 else 0
+        start_idx = int(torch.cat(seqlens_on_cp_ranks[:cp_rank]).sum().item()) if cp_rank != 0 else 0

700-701: postprocess_packed_seqs: redundant .cpu() call may trigger an extra D2H sync.

attention_mask.sum(dim=1, dtype=torch.int32).cpu().tolist() on Line 701 explicitly moves to CPU. If attention_mask is already on GPU, .tolist() alone will perform the D2H transfer. The extra .cpu() call is harmless but redundant. More importantly, this is a synchronization point — consider batching it with other D2H transfers above if performance matters here.


290-292: Magic default token IDs should reference the config, not be hardcoded.

image_token_id: int = 151655 and video_token_id: int = 151656 are hardcoded defaults in reorganize_inputs. These IDs are model-specific and already defined in Qwen3VLTransformerConfig. Hardcoding them here risks silent mismatch if the config changes.

src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/vision_model.py (2)

134-195: fast_pos_embed_interpolate: potential performance concern with .tolist() in inner loop.

Lines 170–171 call .tolist() and .extend() inside the per-image loop, converting GPU tensors to Python lists element by element. For a large number of images/grids, this creates many small D2H synchronizations and Python list operations. Consider accumulating the index and weight tensors on GPU and converting once after the loop.

Also, the unused loop variable t on Line 140 (flagged by Ruff B007) can be replaced with _ to signal intent.
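The accumulate-then-flatten pattern suggested above, sketched with Python lists in place of GPU tensors (in the real code the per-image chunks would stay on GPU and a single `.tolist()`/`torch.cat` would run after the loop):

```python
def gather_indices_batched(grids):
    """Collect per-image index chunks and flatten once, instead of extend() per image."""
    per_image = []
    for h, w in grids:
        # Build each image's flat (row-major) patch indices; no conversion inside the loop.
        per_image.append([r * w + c for r in range(h) for c in range(w)])
    # Single flatten at the end replaces many small list extends (or D2H syncs).
    return [i for chunk in per_image for i in chunk]

assert gather_indices_batched([(2, 2), (1, 3)]) == [0, 1, 2, 3, 0, 1, 2]
```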


239-244: torch.split followed by torch.cat is a no-op.

Lines 242–244 split hidden_states into chunks and immediately concatenate them back. Unless there's a side effect or a planned insertion between split and cat, this can be removed.

-        split_sizes = (grid_thw.prod(-1) // self.spatial_merge_size**2).tolist()
-        hidden_states = torch.split(hidden_states, split_sizes)
-        hidden_states = torch.cat(hidden_states, dim=0)

If this is a placeholder for future logic, add a comment explaining it.

src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/transformer_block.py (2)

64-64: Remove or document commented-out code.

Lines 64 and 76 contain commented-out model_comm_pgs parameters without any explanation. As per coding guidelines, commented-out code should include a comment describing its purpose and why it's commented out, or be removed before merging.

-        # model_comm_pgs: ModelCommProcessGroups = None,
...
-            # model_comm_pgs=model_comm_pgs,

Also applies to: 76-76


336-339: O(n) index lookup self.deepstack_visual_indexes.index(l_no) on every layer forward pass.

Both in the checkpointed path (Line 134) and the non-checkpointed path (Line 337), .index(l_no) performs a linear scan of deepstack_visual_indexes for every matching layer. Consider converting to a dict mapping layer_no → deepstack_idx in __init__ for O(1) lookup.

Proposed optimization in __init__
         self.deepstack_visual_indexes = config.deepstack_visual_indexes
+        self._deepstack_index_map = {
+            l_no: idx for idx, l_no in enumerate(config.deepstack_visual_indexes)
+        }

Then use self._deepstack_index_map[l_no] instead of self.deepstack_visual_indexes.index(l_no).


from megatron.bridge.recipes.qwen_vl.qwen3_vl import qwen3_vl_30b_a3b_pretrain_config
from megatron.bridge.training.pretrain import pretrain
from megatron.bridge.training.vlm_step import forward_step
⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check if vlm_step.forward_step handles Qwen3VL models (look for visual_inputs injection)
echo "=== Checking vlm_step.py for visual_inputs handling ==="
rg -n 'visual_inputs' src/megatron/bridge/training/vlm_step.py

echo ""
echo "=== Checking qwen3vl_step.py for Qwen3VL-specific logic ==="
rg -n 'Qwen3VL\|qwen3vl\|qwen3_vl' src/megatron/bridge/training/qwen3vl_step.py | head -30

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 680


🏁 Script executed:

#!/bin/bash
# Check if qwen3vl_step.py exists
echo "=== Looking for qwen3vl_step.py ==="
find . -name "*qwen3vl_step*" -o -name "*qwen3_vl_step*"

echo ""
echo "=== Check what files exist in training module ==="
ls -la src/megatron/bridge/training/ | grep -E "step\.py|qwen"

echo ""
echo "=== If qwen3vl_step exists, show its full content ==="
if [ -f "src/megatron/bridge/training/qwen3vl_step.py" ]; then
  wc -l src/megatron/bridge/training/qwen3vl_step.py
  cat -n src/megatron/bridge/training/qwen3vl_step.py
fi

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 18492


Import forward_step from qwen3vl_step instead of vlm_step.

This example imports forward_step from vlm_step, which is a generic VLM module. However, since this is a Qwen3-VL specific pretraining script, it should use from megatron.bridge.training.qwen3vl_step import forward_step instead.

qwen3vl_step contains Qwen3-VL–specific optimizations that vlm_step lacks:

  • Assertion that model is Qwen3-VL (line 266-267)
  • Data format handling tailored for Qwen3-VL (bshd vs thd)
  • Position_ids calculation in model forward (not in forward_step)
  • Packed sequence parameter handling specific to Qwen3-VL
  • Proper multi-modal input injection for pixel_values and image_grid_thw

Using the generic vlm_step.forward_step would miss these optimizations and potentially cause training to diverge from the intended Qwen3-VL behavior.

🤖 Prompt for AI Agents
In `@examples/recipes/decentralized_pg/pretrain_qwen3_vl_simple.py` at line 37,
Replace the import of forward_step from the generic vlm_step with the Qwen3-VL
specific implementation: change the import to use
megatron.bridge.training.qwen3vl_step so the script uses forward_step
implemented in qwen3vl_step; this ensures the Qwen3-VL assertions, data-format
handling (bshd vs thd), position_ids logic, packed sequence handling, and
multimodal input injection (pixel_values and image_grid_thw) are applied instead
of the generic vlm_step behavior.

Comment on lines +106 to +108
from megatron.bridge.training.qwen3vl_step import forward_step
from megatron.bridge.utils.common_utils import get_rank_safe

from functools import partial
⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, verify if qwen3vl_step.py has an assertion for Qwen3-VL models
rg -n 'assert.*[Qq]wen' --type=py -C3 src/megatron/bridge/training/qwen3vl_step.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 389


🏁 Script executed:

# Check if partial is actually used anywhere in finetune_qwen_vl.py
rg -n 'partial' examples/recipes/qwen_vl/finetune_qwen_vl.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 104


🏁 Script executed:

# View the imports section of finetune_qwen_vl.py to verify ordering
head -120 examples/recipes/qwen_vl/finetune_qwen_vl.py | tail -40

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1489


🏁 Script executed:

# Check how the script determines recipe_name and model_family
rg -n 'recipe_name|model_family' examples/recipes/qwen_vl/finetune_qwen_vl.py -A2 -B2

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 833


🏁 Script executed:

# Verify if vlm_step.py exists and has a forward_step compatible with Qwen2.5-VL
find . -name "vlm_step.py" -type f

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 113


🏁 Script executed:

# Check what forward_step is in vlm_step.py if it exists
fd -e py vlm_step | head -5 | xargs cat

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 20280


🏁 Script executed:

# Check if there are any other uses of qwen25_vl_recipes in the file to understand full context
rg -n 'qwen25_vl_recipes|Qwen2.5' examples/recipes/qwen_vl/finetune_qwen_vl.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 699


Qwen2.5-VL recipes will fail: forward_step is unconditionally imported from qwen3vl_step, which asserts model is Qwen3-VL.

The script supports both Qwen2.5-VL and Qwen3-VL (line 19, recipe selection lines 204–211), but forward_step is imported from qwen3vl_step before main() runs. The qwen3vl_step.forward_step asserts the model is Qwen3-VL (line 267), causing Qwen2.5-VL recipes to fail.

Move the import into main() after the recipe/model_family is determined and select the appropriate module:

if model_family == "Qwen3-VL":
    from megatron.bridge.training.qwen3vl_step import forward_step
else:
    from megatron.bridge.training.vlm_step import forward_step

Additionally, remove from functools import partial on line 108—it is unused and is a stdlib import incorrectly placed after first-party imports, violating import ordering guidelines.

🧰 Tools
🪛 Flake8 (7.3.0)

[error] 108-108: 'functools.partial' imported but unused

(F401)

🤖 Prompt for AI Agents
In `@examples/recipes/qwen_vl/finetune_qwen_vl.py` around lines 106 - 108, The
file unconditionally imports forward_step from
megatron.bridge.training.qwen3vl_step which asserts the model is Qwen3-VL and
breaks Qwen2.5-VL runs; move that import into main() after you determine
recipe/model_family and conditionally import forward_step from the correct
module (if model_family == "Qwen3-VL" import from
megatron.bridge.training.qwen3vl_step else import from
megatron.bridge.training.vlm_step) so the assertion isn’t triggered at import
time, and remove the unused from functools import partial import which is
misplaced after first-party imports.

Comment on lines +16 to +18
from megatron.core.transformer.attention import *

from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.rope import apply_rotary_pos_emb_absolute
⚠️ Potential issue | 🟠 Major

Replace the star import with explicit dependencies (ruff F403/F405).
The wildcard import is causing undefined-name lint errors. Import the exact symbols used in this file to keep the namespace explicit and pass linting.

🛠️ Proposed fix (explicit imports)
-from megatron.core.transformer.attention import *
+from typing import Optional, Tuple, Union
+
+from torch import Tensor
+from megatron.core.inference.contexts import BaseInferenceContext
+from megatron.core.packed_seq_params import PackedSeqParams
+from megatron.core.transformer.attention import (
+    HAVE_FA3,
+    SelfAttention,
+    is_fa_min_version,
+    rearrange,
+)
+from megatron.core.utils import deprecate_inference_params, nvtx_range_pop, nvtx_range_push
🧰 Tools
🪛 Ruff (0.15.0)

[error] 16-16: from megatron.core.transformer.attention import * used; unable to detect undefined names

(F403)

Comment on lines +109 to +114
# process groups
self.pg_collection = pg_collection
self.cp_group = pg_collection.cp
self.tp_group = pg_collection.tp
self.pp_group = pg_collection.pp
assert hasattr(self.pg_collection, "embd"), (
⚠️ Potential issue | 🟠 Major

Guard against pg_collection=None before dereferencing.
pg_collection is optional in the signature but is immediately accessed (pg_collection.cp). If callers omit it, this will raise.

🛠️ Proposed fix
-        # process groups
-        self.pg_collection = pg_collection
+        if pg_collection is None:
+            pg_collection = get_pg_collection(self)
+        # process groups
+        self.pg_collection = pg_collection
         self.cp_group = pg_collection.cp
         self.tp_group = pg_collection.tp
         self.pp_group = pg_collection.pp
🤖 Prompt for AI Agents
In `@src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py` around lines
109 - 114, Guard against pg_collection being None before dereferencing: check if
pg_collection is not None before accessing pg_collection.cp/tp/pp and assigning
self.pg_collection, self.cp_group, self.tp_group, self.pp_group; if
pg_collection is None set those group attributes to None (or appropriate
defaults) and adjust the assert (or replace it with an explicit check) so you
only call hasattr(self.pg_collection, "embd") when self.pg_collection is not
None. Ensure you update any code paths that assume these groups exist to handle
the None/default case.

Comment on lines +129 to +263
def forward(
    self,
    position_ids: torch.Tensor,
    mrope_section: List[int] = None,
    packed_seq_params: Optional[PackedSeqParams] = None,
    **kwargs,
) -> Tensor:
    """Forward pass for non-MoE Qwen3-VL RoPE.

    Args:
        position_ids: Position IDs tensor
        mrope_section: Optional mrope section (if not provided, uses self.mrope_section)
    """
    if mrope_section is None:
        mrope_section = self.mrope_section

    if position_ids.ndim == 2:
        position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
    inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
    position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)
    freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
    freqs = self.apply_interleaved_mrope(freqs, mrope_section)
    emb = torch.cat((freqs, freqs), dim=-1)
    emb = emb[..., None, :].transpose(0, 1).contiguous()
    _ = packed_seq_params  # packed sequences not supported yet
    return emb

# Slightly modified from Qwen3VLModel.get_rope_index
def get_rope_index(
    spatial_merge_size: int,
    image_token_id: int,
    video_token_id: int,
    vision_start_token_id: int,
    input_ids: Optional[torch.LongTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    packed_seq_params: Optional[PackedSeqParams] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Different from the original implementation, Qwen3VL uses timestamps rather than absolute time position ids."""
    # Since we use timestamps to separate videos, like <t1> <vision_start> <frame1> <vision_end> <t2> <vision_start> <frame2> <vision_end>, the video_grid_thw should also be split
    if video_grid_thw is not None:
        video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0)
        video_grid_thw[:, 0] = 1

    if packed_seq_params is not None and attention_mask is None and input_ids is not None:
        # Build an attention mask from packed sequence metadata when one is not provided.
        # cu_seqlens_q entries are cumulative lengths; their diffs give per-sample lengths.
        cu_seqlens = packed_seq_params.cu_seqlens_q
        if cu_seqlens is not None and cu_seqlens.numel() >= 2:
            seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
            attention_mask = torch.zeros_like(input_ids, dtype=input_ids.dtype)
            max_len = attention_mask.shape[1]
            for i, seq_len in enumerate(seq_lens.tolist()):
                valid = min(int(seq_len), max_len)
                attention_mask[i, :valid] = 1
        else:
            # Fallback to a dense mask if packed metadata is missing.
            attention_mask = torch.ones_like(input_ids)

mrope_position_deltas = []
if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
total_input_ids = input_ids
if attention_mask is None:
attention_mask = torch.ones_like(total_input_ids)
position_ids = torch.ones(
3,
input_ids.shape[0],
input_ids.shape[1],
dtype=input_ids.dtype,
device=input_ids.device,
)
image_index, video_index = 0, 0
attention_mask = attention_mask.to(total_input_ids.device)
for i, input_ids in enumerate(total_input_ids):
input_ids = input_ids[attention_mask[i] == 1]
image_nums, video_nums = 0, 0
vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
vision_tokens = input_ids[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (vision_tokens == video_token_id).sum()
input_tokens = input_ids.tolist()
llm_pos_ids_list: list = []
st = 0
remain_images, remain_videos = image_nums, video_nums
for _ in range(image_nums + video_nums):
if image_token_id in input_tokens and remain_images > 0:
ed_image = input_tokens.index(image_token_id, st)
else:
ed_image = len(input_tokens) + 1
if video_token_id in input_tokens and remain_videos > 0:
ed_video = input_tokens.index(video_token_id, st)
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = (
image_grid_thw[image_index][0],
image_grid_thw[image_index][1],
image_grid_thw[image_index][2],
)
image_index += 1
remain_images -= 1
ed = ed_image

else:
t, h, w = (
video_grid_thw[video_index][0],
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
video_index += 1
remain_videos -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = (
t.item(),
h.item() // spatial_merge_size,
w.item() // spatial_merge_size,
)
text_len = ed - st

st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

# t_index is always 0 because llm_grid_t is always 1 (we use timestamps to encode the temporal information for videos)
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w

if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
return position_ids, mrope_position_deltas
else:
if attention_mask is not None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
else:
position_ids = (
torch.arange(input_ids.shape[1], device=input_ids.device)
.view(1, 1, -1)
.expand(3, input_ids.shape[0], -1)
)
mrope_position_deltas = torch.zeros(
[input_ids.shape[0], 1],
device=input_ids.device,
dtype=input_ids.dtype,
)

return position_ids, mrope_position_deltas

Contributor

⚠️ Potential issue | 🔴 Critical

get_rope_index has multiple likely runtime errors (tensor→int, shadowing, None-handling).

Key blockers:

  • image_nums / video_nums are tensors (.sum()), then used in range(image_nums + video_nums) → TypeError.
  • The loop for i, input_ids in enumerate(total_input_ids): shadows the input_ids argument, making the function harder to reason about and risking bugs.
  • In the “no image/video grids” branch, input_ids can be None while attention_mask is also None; input_ids.shape[...] will crash.
  • vision_start_indices + 1 can go OOB if a <vision_start> token appears at the last position.
Targeted fix sketch (minimal)
 def get_rope_index(
@@
-    if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
+    if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
         total_input_ids = input_ids
@@
-        for i, input_ids in enumerate(total_input_ids):
-            input_ids = input_ids[attention_mask[i] == 1]
+        for i, sample_input_ids in enumerate(total_input_ids):
+            sample_input_ids = sample_input_ids[attention_mask[i] == 1]
@@
-            vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
-            vision_tokens = input_ids[vision_start_indices + 1]
-            image_nums = (vision_tokens == image_token_id).sum()
-            video_nums = (vision_tokens == video_token_id).sum()
-            input_tokens = input_ids.tolist()
+            vision_start_indices = torch.argwhere(sample_input_ids == vision_start_token_id).squeeze(1)
+            vision_start_indices = vision_start_indices[vision_start_indices + 1 < sample_input_ids.numel()]
+            vision_tokens = sample_input_ids[vision_start_indices + 1]
+            image_nums = int((vision_tokens == image_token_id).sum().item())
+            video_nums = int((vision_tokens == video_token_id).sum().item())
+            input_tokens = sample_input_ids.tolist()
@@
-            remain_images, remain_videos = image_nums, video_nums
+            remain_images, remain_videos = image_nums, video_nums
             for _ in range(image_nums + video_nums):
@@
-            mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
-        mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
+            mrope_position_deltas.append(llm_positions.max() + 1 - total_input_ids.shape[1])
+        mrope_position_deltas = torch.stack(mrope_position_deltas).unsqueeze(1)
@@
-    else:
+    else:
+        if attention_mask is None and input_ids is None:
+            raise ValueError("input_ids must be provided when attention_mask is None")
As per coding guidelines, "Avoid shadowing variables declared in an outer scope".
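The tensor-to-int pitfall called out above can be reproduced without any model code. Below is a minimal sketch of the safe counting pattern, using plain Python lists in place of token tensors; the token ids and the `count_vision_segments` helper are illustrative stand-ins, not Qwen's real vocabulary or API:

```python
# Illustrative token ids (not the real Qwen3-VL special-token ids).
VISION_START, IMAGE_TOKEN, VIDEO_TOKEN = 100, 101, 102

def count_vision_segments(sample_input_ids):
    """Count image/video segments as plain ints so they are safe to pass to range()."""
    # Guard the +1 lookup so a trailing vision-start token cannot index out of bounds.
    starts = [i for i, tok in enumerate(sample_input_ids)
              if tok == VISION_START and i + 1 < len(sample_input_ids)]
    vision_tokens = [sample_input_ids[i + 1] for i in starts]
    # sum() over a generator yields a Python int, unlike tensor.sum() which yields
    # a 0-d tensor that breaks range(image_nums + video_nums).
    image_nums = sum(1 for t in vision_tokens if t == IMAGE_TOKEN)
    video_nums = sum(1 for t in vision_tokens if t == VIDEO_TOKEN)
    return image_nums, video_nums

# One image, one video, plus a vision-start token at the very last position
# that must be ignored rather than cause an out-of-bounds lookup.
tokens = [1, VISION_START, IMAGE_TOKEN, 2, VISION_START, VIDEO_TOKEN, 3, VISION_START]
image_nums, video_nums = count_vision_segments(tokens)
segments = list(range(image_nums + video_nums))  # works: both counts are ints
```

Using a distinctly named loop variable (`sample_input_ids` in the diff above) keeps the outer `input_ids` argument intact, which is the same shadowing fix the review sketch applies.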

Comment on lines 146 to 157
model = Qwen3VLModel(
language_transformer_config=language_transformer_config,
language_transformer_layer_spec=language_transformer_layer_spec,
vision_transformer_config=hf_vision_config,
vision_transformer_layer_spec=vision_transformer_layer_spec,
vision_patch_merger_spec=vision_patch_merger_spec,
pre_process=pre_process,
post_process=post_process,
pg_collection=self._pg_collection,
add_encoder=self.add_encoder,
add_decoder=self.add_decoder,
)
Contributor

@coderabbitai coderabbitai bot Feb 13, 2026

⚠️ Potential issue | 🔴 Critical

add_encoder and add_decoder are not defined in the provider class hierarchy.

These attributes are referenced at lines 155-156 (and lines 324-325 in the MoE variant) but are never initialized as dataclass fields in Qwen3VLModelProvider, Qwen3ModelProvider, or their parents. The _pg_collection attribute is properly inherited from GPTModelProvider, but add_encoder and add_decoder must be defined. Either add them as dataclass fields with appropriate defaults or pass hardcoded values like nemotron_vl_provider.py does.

🤖 Prompt for AI Agents
In `@src/megatron/bridge/models/qwen_vl/qwen3_vl_provider.py` around lines 146 -
157, Qwen3VLModelProvider (and Qwen3ModelProvider) reference attributes
add_encoder and add_decoder when constructing Qwen3VLModel but those attributes
are never defined; fix by adding dataclass fields add_encoder: bool = False and
add_decoder: bool = False to the provider class (or its parent) so they're
initialized, or alternatively pass explicit boolean literals into the
Qwen3VLModel constructor where add_encoder/add_decoder are currently used
(mimicking nemotron_vl_provider.py); update both the Qwen3VLModelProvider and
the MoE variant locations that reference add_encoder/add_decoder to use the new
fields or hardcoded values.

Contributor

check this ^^ @shifangx

Contributor

Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!

Comment on lines +136 to +140
if "visual_inputs" in batch:
# convert visual_inputs to multi_modal_inputs which is a dict contains "pixel_values" and "image_grid_thw"
# TODO(jinliangl): add video support
multi_modal_inputs = batch.get("visual_inputs").normalized_for_model()
else:
Contributor

⚠️ Potential issue | 🟠 Major

Guard against visual_inputs=None before calling normalized_for_model().
If the dataset provides visual_inputs=None (e.g., text-only samples), the current code will raise.

🛠️ Proposed fix
-    if "visual_inputs" in batch:
-        # convert visual_inputs to multi_modal_inputs which is a dict contains "pixel_values" and "image_grid_thw"
-        # TODO(jinliangl): add video support
-        multi_modal_inputs = batch.get("visual_inputs").normalized_for_model()
-    else:
-        multi_modal_inputs = {}
+    visual_inputs = batch.get("visual_inputs")
+    if visual_inputs is not None:
+        # convert visual_inputs to multi_modal_inputs which is a dict contains "pixel_values" and "image_grid_thw"
+        # TODO(jinliangl): add video support
+        multi_modal_inputs = visual_inputs.normalized_for_model()
+    else:
+        multi_modal_inputs = {}
🤖 Prompt for AI Agents
In `@src/megatron/bridge/training/qwen3vl_step.py` around lines 136 - 140, The
code currently calls batch.get("visual_inputs").normalized_for_model() without
checking for None; update the block around where "visual_inputs" is handled so
you first assign visual = batch.get("visual_inputs") and only call
visual.normalized_for_model() if visual is not None, otherwise set
multi_modal_inputs to None (or skip creating it) so downstream uses of
multi_modal_inputs are guarded; modify the logic in qwen3vl_step.py where
multi_modal_inputs is created to reference the new visual variable and handle
the None case safely.

Comment on lines +739 to 740
if cfg.model.p2p_communicator.is_pp_last_stage:
# Average loss across microbatches.
Contributor

⚠️ Potential issue | 🟠 Major

Guard last-stage loss reduction when p2p_communicator is unset.
cfg.model.p2p_communicator can be None when PP size is 1; accessing .is_pp_last_stage will raise. Fall back to is_pp_last_stage(pg_collection.pp).

🛠️ Proposed fix
-    if cfg.model.p2p_communicator.is_pp_last_stage:
+    p2p = cfg.model.p2p_communicator
+    if (p2p is not None and p2p.is_pp_last_stage) or (p2p is None and is_pp_last_stage(pg_collection.pp)):
🤖 Prompt for AI Agents
In `@src/megatron/bridge/training/train.py` around lines 739 - 740, Guard the
last-stage loss reduction by checking whether cfg.model.p2p_communicator is not
None before accessing .is_pp_last_stage; if it is None (PP size 1), call
is_pp_last_stage(pg_collection.pp) instead. Update the conditional that
currently reads cfg.model.p2p_communicator.is_pp_last_stage to first test
cfg.model.p2p_communicator and then fall back to
is_pp_last_stage(pg_collection.pp) so the loss averaging path only runs when the
correct determination of last stage succeeds.
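The fallback can be isolated in one small helper so the train loop stays readable. A sketch with a stand-in communicator object; in the real code the fallback would call Megatron's `is_pp_last_stage(pg_collection.pp)` rather than a lambda:

```python
from types import SimpleNamespace


def last_pp_stage(p2p_communicator, fallback_check):
    """Return whether this rank is the last PP stage, tolerating a missing communicator."""
    if p2p_communicator is not None:
        return p2p_communicator.is_pp_last_stage
    # PP size 1: no communicator was built, so fall back to the process-group check.
    return fallback_check()


# Stand-in communicator exposing the attribute the train loop reads.
comm = SimpleNamespace(is_pp_last_stage=True)
with_comm = last_pp_stage(comm, lambda: False)   # communicator wins
without_comm = last_pp_stage(None, lambda: True)  # fallback path
```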

Comment on lines +73 to +74
# Copied from verl/verl/models/mcore/util.py
from megatron.core import mpu
Contributor

🛠️ Refactor suggestion | 🟠 Major

Move import to the top of the file and drop the "Copied from" comment.

from megatron.core import mpu should be with the other imports (after line 19). A bare "Copied from" comment doesn't explain why the code was duplicated — per coding guidelines, comments should describe usage and rationale.

Additionally, both preprocess_packed_seqs and postprocess_packed_seqs are near-duplicates of the versions in src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/utils.py (with the addition of use_fp8_padding). Consider consolidating into a single shared location to avoid drift.

Proposed import fix
 import torch
 import math
 from megatron.core.packed_seq_params import PackedSeqParams
+from megatron.core import mpu
 
 ...
 
-# Copied from verl/verl/models/mcore/util.py
-from megatron.core import mpu
🤖 Prompt for AI Agents
In `@src/megatron/bridge/training/utils/packed_seq_utils.py` around lines 73 - 74,
Move the line "from megatron.core import mpu" up into the module import block
with the other imports (remove the "# Copied from ..." comment entirely), and
ensure the import is colocated with the top-level imports used by this file;
then refactor to avoid duplication by either importing the existing
implementations of preprocess_packed_seqs and postprocess_packed_seqs from the
shared implementation in the qwen_vl utils module or extract a new shared
utility module and have both locations import from it (preserve the
use_fp8_padding behavior when consolidating).

Comment on lines +130 to +157
if pre_process:
input_ids_rmpad = torch.zeros(shape, dtype=input_ids.dtype, device=input_ids.device)
for i in range(batch_size):
# Use Python int, so no GPU→CPU sync in the loop
if cp_size <= 1:
seqlen = seqlens_in_batch_cpu[i]
start_idx = cu_seqlens_padded_cpu[i]
input_ids_rmpad[start_idx : start_idx + seqlen] = input_ids[i, attention_mask[i]]
continue

seqlen_padded_i = seqlens_in_batch_padded_cpu[i]
seqlen = seqlen_padded_i // cp_size
half_seqlen = seqlen // 2
start_idx = cu_seqlens_padded_cpu[i] // cp_size
# split to 2 chunks
d = input_ids[i, attention_mask[i]]
input_ids_rmpad[start_idx : start_idx + half_seqlen] = d[
half_seqlen * cp_rank : half_seqlen * (cp_rank + 1)
]

remain_start = seqlen_padded_i - half_seqlen * (cp_rank + 1)
remain_end = seqlen_padded_i - half_seqlen * cp_rank
remain_end = min(remain_end, d.shape[0])
remain_len = remain_end - remain_start
if remain_len > 0:
input_ids_rmpad[start_idx + half_seqlen : start_idx + half_seqlen + remain_len] = d[
remain_start:remain_end
]
Contributor

⚠️ Potential issue | 🔴 Critical

Bug: attention_mask[i] crashes when attention_mask is None.

Lines 88-92 handle attention_mask is None for the seqlens_in_batch computation, but the loop body on lines 137 and 145 unconditionally indexes attention_mask[i] for boolean masking. When attention_mask is None and pre_process is True, this will raise TypeError: 'NoneType' object is not subscriptable.

Proposed fix
     if pre_process:
         input_ids_rmpad = torch.zeros(shape, dtype=input_ids.dtype, device=input_ids.device)
         for i in range(batch_size):
             # Use Python int, so no GPU→CPU sync in the loop
             if cp_size <= 1:
                 seqlen = seqlens_in_batch_cpu[i]
                 start_idx = cu_seqlens_padded_cpu[i]
-                input_ids_rmpad[start_idx : start_idx + seqlen] = input_ids[i, attention_mask[i]]
+                if attention_mask is not None:
+                    input_ids_rmpad[start_idx : start_idx + seqlen] = input_ids[i, attention_mask[i]]
+                else:
+                    input_ids_rmpad[start_idx : start_idx + seqlen] = input_ids[i, :seqlen]
                 continue
 
             seqlen_padded_i = seqlens_in_batch_padded_cpu[i]
             seqlen = seqlen_padded_i // cp_size
             half_seqlen = seqlen // 2
             start_idx = cu_seqlens_padded_cpu[i] // cp_size
             # split to 2 chunks
-            d = input_ids[i, attention_mask[i]]
+            d = input_ids[i, attention_mask[i]] if attention_mask is not None else input_ids[i, :seqlens_in_batch_cpu[i]]
🤖 Prompt for AI Agents
In `@src/megatron/bridge/training/utils/packed_seq_utils.py` around lines 130 -
157, The loop unconditionally uses attention_mask[i] causing a crash when
attention_mask is None; update the loop in the pre_process branch to guard use
of attention_mask: compute d = input_ids[i, attention_mask[i]] if attention_mask
is not None, otherwise set d = input_ids[i, :seqlen] for the cp_size <= 1 branch
(use seqlen from seqlens_in_batch_cpu[i]) and d = input_ids[i, :seqlen_padded_i]
for the cp_size > 1 branch (use seqlen_padded_i from
seqlens_in_batch_padded_cpu[i]) before the subsequent slicing that writes into
input_ids_rmpad so indexing never tries to subscript None; adjust references
around input_ids_rmpad, seqlens_in_batch_cpu, seqlens_in_batch_padded_cpu,
cu_seqlens_padded_cpu, cp_size and cp_rank.
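The guard pattern in the proposed fix amounts to "boolean-mask when a mask exists, slice when it doesn't". A torch-free sketch with Python lists (the real code operates on tensors, so `zip`-filtering below stands in for `input_ids[i, attention_mask[i]]`):

```python
def gather_valid_tokens(input_ids_row, attention_mask_row, seqlen):
    """Select one sample's valid tokens, tolerating attention_mask=None."""
    if attention_mask_row is not None:
        # Boolean-mask path, mirroring input_ids[i, attention_mask[i]].
        return [tok for tok, keep in zip(input_ids_row, attention_mask_row) if keep]
    # No mask: assume the first seqlen positions are the valid ones,
    # matching the input_ids[i, :seqlen] fallback in the proposed fix.
    return input_ids_row[:seqlen]


row = [7, 8, 9, 0, 0]  # two trailing pad tokens
masked = gather_valid_tokens(row, [True, True, True, False, False], seqlen=3)
unmasked = gather_valid_tokens(row, None, seqlen=3)
```

Both paths yield the same valid prefix, which is why the fallback is safe whenever padding is right-aligned.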

@shifangx shifangx force-pushed the shifang/qwen3_vl_dist_train branch from c86973a to adb12c4 Compare February 14, 2026 14:58
@shifangx shifangx force-pushed the shifang/qwen3_vl_dist_train branch from adb12c4 to c86973a Compare February 15, 2026 00:55
@shifangx shifangx force-pushed the shifang/qwen3_vl_dist_train branch from c86973a to 3382cce Compare February 15, 2026 12:47
@shifangx shifangx force-pushed the shifang/qwen3_vl_dist_train branch from 3382cce to 2f55a42 Compare February 15, 2026 12:56
@shifangx shifangx changed the title [draft]support for training qwen3 vl with dist train support for training qwen3 vl with dist train Feb 23, 2026
@shifangx
Contributor Author

Is there anyone who can help review this PR?

@malay-nagda
Contributor

Is there anyone who can help review this PR?

maybe @yaoyu-33 ?


_pg_collection: Optional[ProcessGroupCollection] = None

# vision model type will be used to override the vision model config.
Contributor

It's a bit weird to add the DistTrain configs and vision configs here. Find a better approach: this shouldn't be exposed to all gpt_provider models, only qwen3_vl. You can add this in the qwen3 provider instead.

This method calculates the data parallel size needed by setup methods, without
triggering full validation or finalization of Megatron Core configs.
"""
if hasattr(self.model, "use_dist_train") and self.model.use_dist_train:
Contributor

Similar here: this over-hijacks the mainstream workflow for DistTrain.

Contributor

Maybe create a separate set_data_parallel_size method for DistTrain; that feels cleaner.

# Distributed - ensure data_parallel_size is calculated (might already be set by set_data_parallel_size)
if not hasattr(self, "data_parallel_size") or self.data_parallel_size is None:
world_size = get_world_size_safe()
if hasattr(self.model, "use_dist_train") and self.model.use_dist_train:
Contributor

Mainly, see if it's possible to not do
"hasattr(self.model, "use_dist_train")"

It adds overhead from a corner use case to the mainstream of validate. It's okay to put all the DistTrain handling together in one block instead of adding conditions to the current mainstream path.

Ask Cursor to come up with some ideas here.
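One way to keep the DistTrain branch out of the mainstream path is a single getattr-based helper plus one dedicated function. A sketch only: the helper names and, in particular, the data-parallel formula under DistTrain (vision ranks carved out before computing LLM dp) are assumptions for illustration, not what the PR necessarily implements:

```python
from types import SimpleNamespace


def dist_train_enabled(model_cfg):
    # One place that tolerates providers lacking the field, instead of
    # scattering hasattr checks through validate().
    return getattr(model_cfg, "use_dist_train", False)


def data_parallel_size(model_cfg, world_size):
    """Compute dp size; the DistTrain split below is illustrative only."""
    tp = model_cfg.tensor_model_parallel_size
    pp = model_cfg.pipeline_model_parallel_size
    if dist_train_enabled(model_cfg):
        # Assumed layout: vision ranks are excluded before computing the LLM dp.
        world_size -= model_cfg.vision_world_size or 0
    return world_size // (tp * pp)


plain = SimpleNamespace(tensor_model_parallel_size=2, pipeline_model_parallel_size=2)
dist = SimpleNamespace(tensor_model_parallel_size=2, pipeline_model_parallel_size=2,
                       use_dist_train=True, vision_world_size=8)
```

`getattr(..., False)` also removes the double attribute lookup that `hasattr(x, "f") and x.f` performs.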

pg_collection_dict: A dictionary mapping module names to ProcessGroupCollections.
"""

def finish_mpu_init() -> ProcessGroupCollection:
Contributor

Does the return annotation need to be updated? It still reads
"-> ProcessGroupCollection:"

An optional callable to finish MPU initialization if skip_mpu_initialization
or lazy_mpu_init is True, otherwise None.
pg_collection: The process group collection initialized for this run.
grid_dict: A dictionary mapping module names to HyperCommGrids.
Contributor

Explain in the docstring when grid_dict and pg_collection_dict would not be None.

dp = torch.distributed.get_world_size() // (tp * pp * cp)
print(f"> initialized HyperCommGrid with tp={tp}, pp={pp}, cp={cp}, dp={dp}")
return pg_collection
if hasattr(model_config, "use_dist_train") and model_config.use_dist_train:
Contributor

Overall, I feel use_dist_train is a train config, i.e. cfg.train.use_dist_train. Check where we set FSDP. Model config seems weird, but it might be okay.

Contributor

This section is a bit too long just to create PGs for multimodal. Offload it to an internal util function, maybe "_create_dist_train_pgs".
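The suggested helper could own the rank partitioning. Below is a sketch under the assumption that the first vision_world_size ranks host the vision encoder and the rest host the LLM; the PR's actual layout may differ, and the real torch.distributed.new_group calls are deliberately elided so the partitioning logic stays testable:

```python
def _split_dist_train_ranks(world_size, vision_world_size):
    """Partition global ranks between vision encoder and LLM (hypothetical layout).

    Real code would then call torch.distributed.new_group(ranks=...) on each list.
    """
    if not 0 < vision_world_size < world_size:
        raise ValueError("vision_world_size must leave at least one LLM rank")
    vision_ranks = list(range(vision_world_size))
    llm_ranks = list(range(vision_world_size, world_size))
    return vision_ranks, llm_ranks


vision, llm = _split_dist_train_ranks(world_size=12, vision_world_size=4)
```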

vision_model_type: Optional[str] = None

# parameters for DistTrain
use_dist_train: bool = False
Contributor

Maybe this belongs in the train config.

# parameters for DistTrain
use_dist_train: bool = False
dist_train_vision_chunk_size: Optional[int] = 1
vision_world_size: Optional[int] = None
Contributor

Maybe group these into a dist_train dataclass.
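A sketch of what grouping the scattered fields into one dataclass might look like. The field names are copied from the config shown above; the grouping itself is only the reviewer's suggestion, not something the PR implements:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class DistTrainConfig:
    """Bundles the DistTrain knobs so model configs carry one optional field."""
    use_dist_train: bool = False
    dist_train_vision_chunk_size: Optional[int] = 1
    vision_world_size: Optional[int] = None


cfg = DistTrainConfig(use_dist_train=True, vision_world_size=8)
```

The model provider would then hold a single `dist_train: Optional[DistTrainConfig] = None` field, and `dist_train is None` replaces the scattered `hasattr(..., "use_dist_train")` checks.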

num_floating_point_operations_model = flop_utils.num_floating_point_operations(config, batch_size=1)
p2p_communicator = P2PCommunicator(pp_group=pg_collection.pp, config=model_config)
if config.model.use_dist_train:
p2p_communicator = config.model.p2p_communicator
Contributor

If it's not a real config field, it won't be saved during checkpointing; use a _ prefix and set it as a private attribute:
"config.model._p2p_communicator"

@yaoyu-33 yaoyu-33 added the needs-author Author action is required before review or merge can continue label Mar 16, 2026