support for training qwen3 vl with dist train #2367

shifangx wants to merge 8 commits into NVIDIA-NeMo:main
Conversation
📝 Walkthrough

This PR adds comprehensive Qwen3VL vision-language model support, including distributed training infrastructure, new vision model components, training step implementations with packed sequence handling, and updates to distributed initialization and process group management.
Sequence Diagram(s)

```mermaid
sequenceDiagram
participant Trainer
participant Setup
participant InitDist as Initialize<br/>Distributed
participant VisionModule
participant LanguageModule
participant Communicator
Trainer->>Setup: setup()
Setup->>InitDist: torch_dist_init()
alt use_dist_train enabled
InitDist->>InitDist: Create HyperCommGrid for vision
InitDist->>InitDist: Create HyperCommGrid for language
InitDist->>Communicator: Build MultiModulePipelineCommunicator
InitDist->>Setup: Return (pg_collection, grid_dict, pg_collection_dict)
Setup->>Setup: Assign vision/language add flags based on rank membership
Setup->>Setup: Wire communicator to model config
else use_dist_train disabled
InitDist->>Communicator: Build P2PCommunicator
InitDist->>Setup: Return pg_collection
Setup->>Setup: Set add_encoder/decoder to True
end
Setup->>Trainer: Return configured setup with communicator
Trainer->>Trainer: Initialize model (Qwen3VL)
Note over VisionModule,LanguageModule: Vision & Language modules<br/>only active based on add flags
Trainer->>Trainer: Training loop
Trainer->>VisionModule: forward() with vision inputs
VisionModule-->>Trainer: vision embeddings + deepstack features
Trainer->>LanguageModule: forward() with language inputs + vision embeddings
LanguageModule-->>Trainer: logits
Trainer->>Communicator: Handle inter-module communication if needed
```

Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
Actionable comments posted: 17
Note
Due to the large number of review comments, Critical and Major severity comments were prioritized as inline comments.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (7)
src/megatron/bridge/training/vlm_step.py (1)
145-192: 🛠️ Refactor suggestion | 🟠 Major

`if True:` creates unreachable dead code — clean up before merging.

The `else` branch (lines 159-192) is unreachable. This looks like a debug/WIP artifact. Either remove the dead branch or replace `if True:` with the intended condition. Note: the dead `else` branch also has a latent bug at line 190 where `if attn is not None:` is outside the `if tokens_or_input is not None:` block, making `target_len` potentially undefined — but since it's dead code, this is moot for now.

Proposed cleanup

```diff
     # When using pipeline parallelism, ensure fixed shapes equal to cfg.model.seq_length
-    if True:
-        seq_len = cfg.model.seq_length
+    seq_len = cfg.model.seq_length
-        tokens_or_input = batch.get("tokens") if batch.get("tokens") is not None else batch.get("input_ids")
-        tokens_or_input = pad_or_truncate_2d_to_len(tokens_or_input, seq_len, seq_len, pad_value=0)
-        ...
-    else:
-        # No PP: pad sequence length to nearest multiple of 128 for efficiency (capped at model seq_length)
-        ...
+    tokens_or_input = batch.get("tokens") if batch.get("tokens") is not None else batch.get("input_ids")
+    tokens_or_input = pad_or_truncate_2d_to_len(tokens_or_input, seq_len, seq_len, pad_value=0)
+    ...
```

As per coding guidelines: "If code is commented out, include a comment describing its usage and why it is commented out; otherwise remove it as debug code before merging."
src/megatron/bridge/training/initialize.py (2)
134-214: ⚠️ Potential issue | 🟠 Major

`torch_dist_init`'s "triple return" contract is inconsistent across branches (lazy_init / skip_mpu_initialization).

`torch_dist_init()` is annotated to return `(ProcessGroupCollection, grid_dict, pg_collection_dict)`, but it can return:

- `(None, None, None)` when `skip_mpu_initialization=True`, and
- `(finish_mpu_init, None, None)` when `dist_config.lazy_init=True` (callable in slot 0),

and `finish_mpu_init()` is annotated as returning `ProcessGroupCollection` but returns the full 3-tuple. This mismatch makes it very easy for downstream code to treat the first element as a `ProcessGroupCollection` and crash (especially after the new destructuring pattern in setup code).

Consider making the return type explicit and self-describing (e.g., a small dataclass/NamedTuple with `pg_collection`, `grid_dict`, `pg_collection_dict`, and an optional `finish_mpu_init`), or at minimum fix the type hints and docstrings so call sites can branch safely on `callable(pg_collection)` and `pg_collection is None`.

As per coding guidelines, "Use type hints for function arguments and return types" and "Use 'T | None' for nullable types instead of 'Optional[T]'".
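A minimal sketch of such a self-describing return type, under the assumption that the names below are illustrative rather than the actual Megatron-Bridge API:

```python
from typing import Callable, NamedTuple

from megatron.core.process_groups_config import ProcessGroupCollection


class DistInitResult(NamedTuple):
    """Hypothetical container for torch_dist_init's results plus the lazy-init hook."""

    pg_collection: ProcessGroupCollection | None
    grid_dict: dict | None
    pg_collection_dict: dict | None
    finish_mpu_init: Callable[[], "DistInitResult"] | None = None


def consume(result: DistInitResult) -> ProcessGroupCollection:
    # Callers can branch on explicit fields instead of guessing tuple slots.
    if result.finish_mpu_init is not None:
        result = result.finish_mpu_init()
    if result.pg_collection is None:
        raise RuntimeError("distributed init was skipped; no ProcessGroupCollection available")
    return result.pg_collection
```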
363-375: ⚠️ Potential issue | 🟡 Minor

Fix implicit Optional type hints and use the built-in tuple generic (ruff RUF013).

Parameters `world_size` and `rank_offset` use implicit Optional syntax and should use union types. The return type should use the built-in `tuple` generic instead of `Tuple` from typing.

Proposed diff

```diff
 def _create_pg_collection(
     model_config: GPTModelProvider | T5ModelProvider,
     num_distributed_optimizer_instances: int,
     get_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]] = None,
     get_position_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]] = None,
-    world_size: int = None,
-    rank_offset: int = None,
-) -> Tuple[ProcessGroupCollection, HyperCommGrid]:
+    world_size: int | None = None,
+    rank_offset: int | None = None,
+) -> tuple[ProcessGroupCollection, HyperCommGrid]:
```

src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/transformer_config.py (1)
15-119: ⚠️ Potential issue | 🟠 Major

Remove the unused `torch` import, fix assertion error handling, and avoid mutating `hf_config` in-place.

Issues to address:

- `import torch` is unused (line 20 shows it's imported but never referenced; only `torch.nn.functional` is used).
- `assert config.vision_model_type is None, ValueError(...)` is incorrect syntax—it raises `AssertionError` with a `ValueError` object as the message, and assertions can be disabled at runtime. Use `raise NotImplementedError(...)` instead to fail fast with a clear error.
- Mutating `hf_config` in-place (depth, hidden_size, num_heads, etc.) couples the function to the input object and risks unintended side effects if `vision_transformer_config` is reused elsewhere. Store these values in local variables or construct a copy of the config to avoid mutation.

Proposed diff

```diff
-import torch
 import torch.nn.functional as F
@@
 def get_vision_model_config(config: Qwen3VLTransformerConfig, hf_config):
+    """Populate a Qwen3VLTransformerConfig instance with vision-model settings.
+
+    Note: This function mutates and returns `config`.
+    """
     config.num_moe_experts = None
     config.expert_model_parallel_size = 1
     config.moe_ffn_hidden_size = None
@@
     if config.vision_model_type == "vit_2b":
-        hf_config.depth = 45
-        hf_config.hidden_size = 1536
-        hf_config.num_heads = 16
-        hf_config.intermediate_size = 8960
-        hf_config.patch_size = 16
-        hf_config.spatial_merge_size = 2
-        if hasattr(hf_config, "head_dim"):
-            hf_config.head_dim = 96
+        hf_depth = 45
+        hf_hidden_size = 1536
+        hf_num_heads = 16
+        hf_intermediate_size = 8960
+        hf_patch_size = 16
+        hf_spatial_merge_size = 2
+        hf_head_dim = 96
     else:
-        assert config.vision_model_type is None, ValueError(f"support only vit_2b, but got {config.vision_model_type}")
+        raise NotImplementedError(f"Only vision_model_type='vit_2b' is supported, got: {config.vision_model_type}")
@@
-    config.num_layers = hf_config.depth
-    config.ffn_hidden_size = hf_config.intermediate_size
-    config.num_attention_heads = hf_config.num_heads  # num_heads
+    config.num_layers = hf_depth
+    config.ffn_hidden_size = hf_intermediate_size
+    config.num_attention_heads = hf_num_heads  # num_heads
@@
-    config.hidden_size = hf_config.hidden_size  # hidden_size
+    config.hidden_size = hf_hidden_size  # hidden_size
@@
-    config.patch_size = hf_config.patch_size
+    config.patch_size = hf_patch_size
@@
-    config.spatial_merge_size = hf_config.spatial_merge_size
+    config.spatial_merge_size = hf_spatial_merge_size
```

src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/rope.py (1)
30-126: ⚠️ Potential issue | 🟠 Major

Register `inv_freq` as a buffer and fix type hints to use `|` instead of `Optional`.

Two issues need fixing:

- `inv_freq` must be a buffer: currently stored as a plain tensor attribute (lines 65-67), it won't track device moves when `module.to(device)` is called and won't serialize/deserialize properly with state_dict. The codebase already uses this pattern (see `utils.py` line 84). Additionally, specifying `device=torch.cuda.current_device()` at initialization is problematic—the tensor should be created on the default device and follow module movements.
- Type hints: per coding guidelines, use `T | None` instead of `Optional[T]`. Lines 52 and 54 should use the modern union syntax.

Proposed diff

```diff
     def __init__(
         self,
         kv_channels: int,
         rotary_percent: float,
         rotary_interleaved: bool = False,
-        seq_len_interpolation_factor: Optional[float] = None,
+        seq_len_interpolation_factor: float | None = None,
         rotary_base: int = 10000,
-        cp_group: torch.distributed.ProcessGroup = None,
+        cp_group: torch.distributed.ProcessGroup | None = None,
     ) -> None:
         super().__init__()
         dim = kv_channels
         if rotary_percent < 1.0:
             dim = int(dim * rotary_percent)
         self.rotary_interleaved = rotary_interleaved
         assert not self.rotary_interleaved, "only support qwen3vl"
         self.seq_len_interpolation_factor = seq_len_interpolation_factor
-        self.inv_freq = 1.0 / (
-            rotary_base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=torch.cuda.current_device()) / dim)
-        )
+        inv_freq = 1.0 / (rotary_base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.is_thd_format = False
         self.cp_group = cp_group
```

src/megatron/bridge/training/setup.py (1)
52-167: ⚠️ Potential issue | 🔴 Critical

setup(): handle lazy-init and skip-mpu-init returns before accessing `pg_collection.pp`.

When `lazy_init=True`, `initialize_megatron` returns a callable as the first element of the tuple; when `skip_mpu_initialization=True`, it returns `None`. Directly accessing `pg_collection.pp` on line 167 will crash in both cases with `AttributeError` or `TypeError`.

Add type checking after unpacking:

- If the first element is callable, invoke it to get the real `(pg_collection, grid_dict, pg_collection_dict)`
- If `pg_collection` is still `None`, raise a clear error

Proposed fix

```diff
     pg_collection, grid_dict, pg_collection_dict = initialize_megatron(
         cfg=cfg,
         get_embedding_ranks=get_embedding_ranks,
         get_position_embedding_ranks=get_position_embedding_ranks,
         restart_store=restart_store,
     )
+    if callable(pg_collection):
+        pg_collection, grid_dict, pg_collection_dict = pg_collection()
+
+    if pg_collection is None:
+        raise RuntimeError("initialize_megatron did not return a ProcessGroupCollection (pg_collection=None)")
+
     if hasattr(cfg.model, "use_dist_train") and cfg.model.use_dist_train:
```

tests/unit_tests/models/qwen_vl/modelling_qwen3_vl/test_model.py (1)
314-323: ⚠️ Potential issue | 🟠 Major

Pass `pg_collection` to all `Qwen3VLModel` instances.

`Qwen3VLModel` now dereferences `pg_collection.*`; the constructors for `model_no_decoder` and `model_no_pre` will raise when `pg_collection` is omitted. Reuse the collection already created in the test.

🛠️ Proposed fix

```diff
 model_no_decoder = Qwen3VLModel(
     vision_transformer_config=vision_transformer_config,
     language_transformer_config=language_transformer_config,
     language_transformer_layer_spec=language_model_layer_spec,
     parallel_output=True,
     pre_process=True,
     post_process=True,
     add_encoder=True,
     add_decoder=False,
+    pg_collection=pg_collection,
 )
@@
 model_no_pre = Qwen3VLModel(
     vision_transformer_config=vision_transformer_config,
     language_transformer_config=language_transformer_config,
     language_transformer_layer_spec=language_model_layer_spec,
     parallel_output=True,
     pre_process=False,
     post_process=True,
     add_encoder=True,
     add_decoder=True,
+    pg_collection=pg_collection,
 )
```

Also applies to: 366-375
🤖 Fix all issues with AI agents
In `@examples/recipes/decentralized_pg/pretrain_qwen3_vl_simple.py`:
- Line 37: Replace the import of forward_step from the generic vlm_step with the
Qwen3-VL specific implementation: change the import to use
megatron.bridge.training.qwen3vl_step so the script uses forward_step
implemented in qwen3vl_step; this ensures the Qwen3-VL assertions, data-format
handling (bshd vs thd), position_ids logic, packed sequence handling, and
multimodal input injection (pixel_values and image_grid_thw) are applied instead
of the generic vlm_step behavior.
In `@examples/recipes/qwen_vl/finetune_qwen_vl.py`:
- Around line 106-108: The file unconditionally imports forward_step from
megatron.bridge.training.qwen3vl_step which asserts the model is Qwen3-VL and
breaks Qwen2.5-VL runs; move that import into main() after you determine
recipe/model_family and conditionally import forward_step from the correct
module (if model_family == "Qwen3-VL" import from
megatron.bridge.training.qwen3vl_step else import from
megatron.bridge.training.vlm_step) so the assertion isn’t triggered at import
time, and remove the unused from functools import partial import which is
misplaced after first-party imports.
In `@src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py`:
- Around line 109-114: Guard against pg_collection being None before
dereferencing: check if pg_collection is not None before accessing
pg_collection.cp/tp/pp and assigning self.pg_collection, self.cp_group,
self.tp_group, self.pp_group; if pg_collection is None set those group
attributes to None (or appropriate defaults) and adjust the assert (or replace
it with an explicit check) so you only call hasattr(self.pg_collection, "embd")
when self.pg_collection is not None. Ensure you update any code paths that
assume these groups exist to handle the None/default case.
In `@src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/text_model.py`:
- Around line 84-92: The constructor dereferences pg_collection.cp when creating
self.rotary_pos_emb (Qwen3VLMultimodalRotaryEmbedding) but pg_collection may be
None; guard this by checking/initializing pg_collection before use or asserting
non-None: e.g., if pg_collection is None create a default with the expected .cp
attribute (or raise a clear error), then pass pg_collection.cp into
Qwen3VLMultimodalRotaryEmbedding; ensure the change touches the place where
rotary_pos_emb is assigned and uses pg_collection only after the guard.
In `@src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/transformer_block.py`:
- Around line 86-95: The deepstack_merger_list constructs
Qwen3VLVisionPatchMerger instances without passing the block's tensor-parallel
group, causing linear layers (linear_fc1/linear_fc2) inside those mergers to use
the wrong TP group; update the Qwen3VLVisionPatchMerger construction inside
deepstack_merger_list to include tp_group=self.tp_group (same pattern used in
vision_model where mergers pass tp_group=self.tp_group) so each merger uses the
block's TP group for its linear layers, ensuring consistent tensor-parallel
behavior across config.deepstack_visual_indexes.
- Around line 429-433: The call to sharded_state_dict_default inside
Qwen3VLVisionTransformerBlock's sharded_state_dict loop will raise NameError and
the identity check uses non-PEP8 syntax; fix by importing or providing a
fallback implementation of sharded_state_dict_default (used by
sharded_state_dict when iterating named_children) and replace "not module is
self.layers" / "not module is self.deepstack_merger_list" with "module is not
self.layers" / "module is not self.deepstack_merger_list" respectively; ensure
the imported or fallback sharded_state_dict_default signature matches (module,
prefix, sharded_offsets, metadata) so sharded_state_dict can call it safely.
- Around line 363-435: The method TransformerBlock.sharded_state_dict calls an
undefined helper sharded_state_dict_default and also bypasses parent logic; fix
by first invoking super().sharded_state_dict(prefix, sharded_offsets, metadata)
and merging its result, replace the undefined call by importing/using the
correct helper function (or the module-level utility that actually exists in the
codebase) instead of sharded_state_dict_default when iterating named_children
(reference the symbol sharded_state_dict_default and ensure it matches the real
exported name), and keep using self.pp_group.rank() but verify pp_group
implements rank() where assigned; in short: call super(), swap the undefined
helper for the correct imported helper, and merge results rather than fully
overriding parent behavior.
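A hedged sketch of the super()-plus-merge pattern described in that prompt. It assumes the parent class provides `sharded_state_dict` and that Megatron Core exports `sharded_state_dict_default` from `megatron.core.transformer.utils`; names follow the prompt and are not verified against the PR:

```python
from megatron.core.transformer.utils import sharded_state_dict_default  # assumed location


def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
    # Start from the parent implementation instead of fully overriding it.
    sharded_sd = super().sharded_state_dict(prefix, sharded_offsets, metadata)
    for name, module in self.named_children():
        # PEP 8 identity checks; skip children the parent loop handles separately.
        if module is not self.layers and module is not self.deepstack_merger_list:
            sharded_sd.update(
                sharded_state_dict_default(module, f"{prefix}{name}.", sharded_offsets, metadata)
            )
    return sharded_sd
```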
In `@src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/utils.py`:
- Around line 366-372: In split_data_cp_rank, move the None check for the input
tensor before any attribute access so val is returned early if None (i.e., check
`if val is None: return val` before using `val.shape`), and update the cp_rank
annotation from `cp_rank: int = None` to use an explicit Optional type
(`cp_rank: Optional[int] = None`) to comply with PEP 484; keep the rest of the
logic (cp_size assert and cp_rank defaulting via
mpu.get_context_parallel_rank()) unchanged.
- Around line 349-350: Replace the unreachable assert with a real exception: in
the else branch that currently does "assert False, f'should not have
{token_id=}'", raise a ValueError instead (e.g. "raise ValueError(f'should not
have token_id={token_id}')") so the error is not stripped by python -O and
clearly reports the invalid token_id; update the branch where token_id is
handled accordingly.
In `@src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/vision_model.py`:
- Around line 42-50: The constructor currently defaults pg_collection to None
but accesses self.pg_collection.tp in __init__, which can raise AttributeError;
update the __init__ of the class to validate pg_collection (the pg_collection
parameter and self.pg_collection) before accessing .tp — e.g., if pg_collection
is None raise a clear ValueError stating "pg_collection is required" (or
alternatively assign self.tp_group = None when pg_collection is None), then set
self.pg_collection = pg_collection and self.tp_group = self.pg_collection.tp
only after the null check so the attribute access is safe.
In `@src/megatron/bridge/models/qwen_vl/qwen3_vl_provider.py`:
- Around line 146-157: Qwen3VLModelProvider (and Qwen3ModelProvider) reference
attributes add_encoder and add_decoder when constructing Qwen3VLModel but those
attributes are never defined; fix by adding dataclass fields add_encoder: bool =
False and add_decoder: bool = False to the provider class (or its parent) so
they're initialized, or alternatively pass explicit boolean literals into the
Qwen3VLModel constructor where add_encoder/add_decoder are currently used
(mimicking nemotron_vl_provider.py); update both the Qwen3VLModelProvider and
the MoE variant locations that reference add_encoder/add_decoder to use the new
fields or hardcoded values.
In `@src/megatron/bridge/training/qwen3vl_step.py`:
- Around line 136-140: The code currently calls
batch.get("visual_inputs").normalized_for_model() without checking for None;
update the block around where "visual_inputs" is handled so you first assign
visual = batch.get("visual_inputs") and only call visual.normalized_for_model()
if visual is not None, otherwise set multi_modal_inputs to None (or skip
creating it) so downstream uses of multi_modal_inputs are guarded; modify the
logic in qwen3vl_step.py where multi_modal_inputs is created to reference the
new visual variable and handle the None case safely.
In `@src/megatron/bridge/training/train.py`:
- Around line 739-740: Guard the last-stage loss reduction by checking whether
cfg.model.p2p_communicator is not None before accessing .is_pp_last_stage; if it
is None (PP size 1), call is_pp_last_stage(pg_collection.pp) instead. Update the
conditional that currently reads cfg.model.p2p_communicator.is_pp_last_stage to
first test cfg.model.p2p_communicator and then fall back to
is_pp_last_stage(pg_collection.pp) so the loss averaging path only runs when the
correct determination of last stage succeeds.
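A minimal sketch of that guard (the attribute and helper names come from the prompt above; the exact placement in train.py is an assumption):

```python
# Determine last-stage membership without dereferencing a None communicator (PP size 1).
if cfg.model.p2p_communicator is not None:
    last_stage = cfg.model.p2p_communicator.is_pp_last_stage
else:
    last_stage = is_pp_last_stage(pg_collection.pp)

if last_stage:
    ...  # run the loss-averaging path only once last-stage status is known
```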
In `@src/megatron/bridge/training/utils/packed_seq_utils.py`:
- Around line 73-74: Move the line "from megatron.core import mpu" up into the
module import block with the other imports (remove the "// Copied from ..."
comment entirely), and ensure the import is colocated with the top-level imports
used by this file; then refactor to avoid duplication by either importing the
existing implementations of preprocess_packed_seqs and postprocess_packed_seqs
from the shared implementation in the qwen_vl utils module or extract a new
shared utility module and have both locations import from it (preserve the
use_fp8_padding behavior when consolidating).
- Around line 130-157: The loop unconditionally uses attention_mask[i] causing a
crash when attention_mask is None; update the loop in the pre_process branch to
guard use of attention_mask: compute d = input_ids[i, attention_mask[i]] if
attention_mask is not None, otherwise set d = input_ids[i, :seqlen] for the
cp_size <= 1 branch (use seqlen from seqlens_in_batch_cpu[i]) and d =
input_ids[i, :seqlen_padded_i] for the cp_size > 1 branch (use seqlen_padded_i
from seqlens_in_batch_padded_cpu[i]) before the subsequent slicing that writes
into input_ids_rmpad so indexing never tries to subscript None; adjust
references around input_ids_rmpad, seqlens_in_batch_cpu,
seqlens_in_batch_padded_cpu, cu_seqlens_padded_cpu, cp_size and cp_rank.
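A sketch of the attention-mask guard described in that prompt; the variable names mirror the prompt and the surrounding loop is assumed, not copied from the PR:

```python
for i in range(batch_size):
    if attention_mask is not None:
        d = input_ids[i, attention_mask[i]]  # keep only unmasked tokens
    elif cp_size <= 1:
        d = input_ids[i, : int(seqlens_in_batch_cpu[i])]
    else:
        d = input_ids[i, : int(seqlens_in_batch_padded_cpu[i])]
    # Subsequent writes into input_ids_rmpad use `d`, so indexing
    # never tries to subscript a None attention_mask.
```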
🟡 Minor comments (12)
src/megatron/bridge/training/vlm_step.py-124-125 (1)

124-125: ⚠️ Potential issue | 🟡 Minor

Hardcoded `is_first = True` / `is_last = True` bypasses pipeline-stage detection.

This forces all ranks to load labels/loss_mask and visual inputs, regardless of their actual PP stage. If this is intentional for the multi-module VLM distributed training path, please add a comment explaining why pipeline-stage gating was removed (and whether the original `is_pp_first_stage`/`is_pp_last_stage` imports on line 21 are now dead).

examples/recipes/decentralized_pg/pretrain_qwen3_vl_simple.py-41-43 (1)

41-43: ⚠️ Potential issue | 🟡 Minor

Misleading comment: "Qwen3 4B" vs actual 30B config.

Line 42 says "Get the standard Qwen3 4B pretrain config" but line 43 uses `qwen3_vl_30b_a3b_pretrain_config`.

Fix

```diff
-    # Get the standard Qwen3 4B pretrain config with overrides
+    # Get the standard Qwen3-VL 30B MoE pretrain config with overrides
```

src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/rope.py-265-307 (1)
265-307: ⚠️ Potential issue | 🟡 Minor

Silence the unused `cu_seqlens` warning and make RoPE-fusion failure explicit.

- `apply_rotary_pos_emb_thd_absolute(..., cu_seqlens, ...)` doesn't use `cu_seqlens` (ruff ARG001). Rename it to `_cu_seqlens` (or add a targeted noqa) to keep the interface but avoid lint noise.
- `assert not config.apply_rope_fusion` will be stripped under `python -O`; prefer a real exception (`NotImplementedError`/`ValueError`) if this is a hard constraint.

As per coding guidelines, "When a feature is not supported (such as audio embeddings), raise an explicit error (e.g., NotImplementedError) instead of silently ignoring the input to fail fast with a clear message." Based on learnings: "when a feature is not supported ... raise an explicit error (e.g., NotImplementedError) instead of silently ignoring".
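A minimal sketch of the explicit-error pattern; the condition is taken from the review and the surrounding code is assumed:

```python
if config.apply_rope_fusion:
    # Fail fast even under `python -O`, where bare asserts are stripped.
    raise NotImplementedError(
        "apply_rope_fusion is not supported by the Qwen3-VL absolute RoPE path"
    )
```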
src/megatron/bridge/training/qwen3vl_step.py-41-47 (1)
41-47: ⚠️ Potential issue | 🟡 Minor

Silence unused-argument warnings and remove the unused variable.

Ruff flags ARG001/F841 in this function.

🛠️ Proposed fix

```diff
     batch = next(data_iterator)
+    _ = use_mtp
+    _ = is_first_pp_stage
+    _ = is_last_pp_stage
@@
-    max_seqlen_in_batch = seqlens_in_batch.max().item()
```

Also applies to: 224-224
src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py-16-18 (1)
16-18: ⚠️ Potential issue | 🟡 Minor

Align type hints with repo conventions and drop the unused import.

Use built-in generics/union types and remove the unused `mpu` import to satisfy lint and style rules. As per coding guidelines, use built-in generics (list, dict, tuple) instead of typing equivalents, and use 'T | None' for nullable types instead of 'Optional[T]'.

🛠️ Proposed fix

```diff
-from typing import List, Dict
+
@@
-from megatron.core import InferenceParams, mpu, tensor_parallel
+from megatron.core import InferenceParams, tensor_parallel
@@
-    def set_input_tensor(self, input_tensor: List[Dict[str, torch.Tensor]]):
+    def set_input_tensor(self, input_tensor: list[dict[str, torch.Tensor]]):
@@
-        cp_img_num: list[int] = None,
-        images_padded: list[bool] = None,
+        cp_img_num: list[int] | None = None,
+        images_padded: list[bool] | None = None,
```

Also applies to: 195-200, 300-301
-from typing import List, Dict + @@ -from megatron.core import InferenceParams, mpu, tensor_parallel +from megatron.core import InferenceParams, tensor_parallel @@ - def set_input_tensor(self, input_tensor: List[Dict[str, torch.Tensor]]): + def set_input_tensor(self, input_tensor: list[dict[str, torch.Tensor]]): @@ - cp_img_num: list[int] = None, - images_padded: list[bool] = None, + cp_img_num: list[int] | None = None, + images_padded: list[bool] | None = None,Also applies to: 195-200, 300-301
src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/attention.py-40-41 (1)
40-41:⚠️ Potential issue | 🟡 MinorFix lint issues: unused parameter and f-string without placeholders.
Ruff flags the unusedrotary_pos_cos_sinargument and thefprefix on a static string.🛠️ Proposed fix
- rotary_pos_cos_sin: Optional[Tensor] = None, + _rotary_pos_cos_sin: Optional[Tensor] = None, @@ - raise ValueError(f"CUDA graphs must use flash decode with static batching!") + raise ValueError("CUDA graphs must use flash decode with static batching!")Also applies to: 126-127
src/megatron/bridge/training/train.py-39-40 (1)
39-40:⚠️ Potential issue | 🟡 MinorRemove unused
MultiModulePipelineCommunicatorimport (F401).
Lint currently flags this as unused.🛠️ Proposed fix
-from megatron.core.pipeline_parallel.multimodule_communicator import MultiModulePipelineCommunicatorsrc/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py-152-155 (1)
152-155:⚠️ Potential issue | 🟡 MinorUse
print_rank_0for model logging.
This avoids duplicate logs across ranks during distributed training.As per coding guidelines, Use 'print_rank_0' for logging in model bridge to avoid duplicate output across ranks.🛠️ Proposed fix
-from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig as Qwen3VLConfigHF +from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig as Qwen3VLConfigHF +from megatron.bridge.utils.common_utils import print_rank_0 @@ - print(f"rank {torch.distributed.get_rank()} use hf vision model") + print_rank_0(f"rank {torch.distributed.get_rank()} use hf vision model")src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/utils.py-242-242 (1)
242-242: ⚠️ Potential issue | 🟡 Minor

Typo: "flaaten" → "flattened".

```diff
-    assert input_ids.dim() == 1, "input_ids should be flaaten"
+    assert input_ids.dim() == 1, "input_ids should be flattened"
```

src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/vision_model.py-1-24 (1)
1-24: ⚠️ Potential issue | 🟡 Minor

Missing NVIDIA copyright header and unused imports.

This file is missing the required NVIDIA copyright header. Additionally, static analysis flags several unused imports: `dataclass`, `Union`, `MultimodalProjector`, `MegatronModule`, `build_module`, and `get_tensor_model_parallel_group_if_none`.

Proposed fix

```diff
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
-from dataclasses import dataclass
-from typing import Optional, Union
+from __future__ import annotations

 import torch
 from megatron.core import InferenceParams
 from megatron.core.models.common.vision_module.vision_module import VisionModule
-from megatron.core.models.vision.multimodal_projector import MultimodalProjector
 from megatron.core.packed_seq_params import PackedSeqParams
 from megatron.core.process_groups_config import ProcessGroupCollection
 from megatron.core.transformer.enums import ModelType
-from megatron.core.transformer.module import MegatronModule
-from megatron.core.transformer.spec_utils import ModuleSpec, build_module
-from megatron.core.utils import get_tensor_model_parallel_group_if_none
+from megatron.core.transformer.spec_utils import ModuleSpec
 from torch import nn
 from torch.nn import functional as F
```

As per coding guidelines: "Add NVIDIA copyright header to all Python files" and "Use `T | None` for nullable types instead of `Optional[T]`".

src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/vision_model.py-197-245 (1)
197-245: ⚠️ Potential issue | 🟡 Minor

Docstring args/return mismatch with actual signature and return value.

- The docstring refers to parameter `x` (Line 208) but the actual parameter is `hidden_states`.
- The docstring mentions `packed_seq_params` (Line 210), which is not in the function signature (it's computed internally on Line 235).
- The return type annotation says `torch.Tensor` (Line 203) but the function returns a tuple `(hidden_states, deepstack_feature_lists)` on Line 245.

Proposed fix

```diff
     def forward(
         self,
         hidden_states: Optional[torch.Tensor],
         grid_thw: torch.Tensor,
         inference_params: Optional[InferenceParams] = None,
         extra_block_kwargs: dict = None,
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, list]:
         """Forward function of the Qwen3 Vision Model. This function passes the input tensors
         through the embedding layer and then the transformer.

         Args:
-            x (torch.Tensor): input image/video data of shape [n_tokens, n_dims]
+            hidden_states (torch.Tensor): input image/video data of shape [n_tokens, n_dims]
             grid_thw (torch.Tensor): the size tensor indicates grid size of each image/frame
-            packed_seq_params (PackedSeqParams): parameters to build attention mask in the backend
+            inference_params (InferenceParams, optional): inference parameters
+            extra_block_kwargs (dict, optional): additional keyword arguments for the decoder block

         Returns:
-            x (torch.Tensor): output after final transformer block of shape [b, s, h].
+            tuple: (hidden_states, deepstack_feature_lists)
         """
```

src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/utils.py-666-679 (1)
666-679: ⚠️ Potential issue | 🟡 Minor

`cu_seqlens_padded` is passed undivided to `PackedSeqParams`, but this is currently a latent issue since the returned `packed_seq_params` is never used.

While the code structure does create a mismatch when `cp_size > 1` (the buffer is sized as `sum(seqlens_in_batch_padded_cpu) // cp_size` but `cu_seqlens_padded` reflects full, undivided cumulative lengths), this is not a critical runtime bug in the current codebase. All calls to `preprocess_packed_seqs()` discard the returned `packed_seq_params` with `_` (lines 437, 456, 467, 523 in model.py), and the parameter passed to the language model is always `None` (line 536).

The data itself is correctly divided during packing (indices are adjusted: `start_idx = cu_seqlens_padded_cpu[i] // cp_size` on line 650). However, if `packed_seq_params` were ever used downstream with CP enabled, the mismatch would cause incorrect indexing. Consider either: (1) adjust `cu_seqlens_padded` and `max_seqlen_in_batch` for the per-rank view when `cp_size > 1` (see the sketch below), or (2) clarify whether this function is intended to support CP splitting and document the limitation.
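A sketch of option (1), assuming the `cu_seqlens_padded` entries are all multiples of `cp_size` (which the padding is meant to guarantee); the variable names mirror the review's description, not necessarily the exact code:

```python
# Per-rank view of the packed-sequence metadata when context parallelism is on.
if cp_size > 1:
    cu_seqlens_padded = cu_seqlens_padded // cp_size  # cumulative lengths per CP rank
    max_seqlen_in_batch = int(
        (cu_seqlens_padded[1:] - cu_seqlens_padded[:-1]).max().item()
    )
```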
🧹 Nitpick comments (17)
scripts/performance/argument_parser.py (1)
148-148: Consider whether `qwen3vl` belongs as a `--domain` choice or should be handled via `--model_family_name`.

The existing domain values (`llm`, `vlm`) are generic categories, whereas `qwen3vl` is model-specific. This sets a precedent where each new VLM variant could require its own domain entry. If the Qwen3-VL training path diverges significantly enough from the general `vlm` path to justify a separate domain, this is fine — but if the differences are minor, routing through `--model_family_name` (or a sub-option) would keep the domain list stable.

examples/recipes/qwen_vl/finetune_qwen_vl.py (1)
93-93: Use built-in `tuple` instead of `typing.Tuple` for Python 3.10+.

Line 118 uses `Tuple[argparse.Namespace, list[str]]`, mixing `typing.Tuple` with built-in `list`. Per coding guidelines, prefer built-in generics.

Proposed fix

```diff
-from typing import Tuple
```

And on line 118:

```diff
-def parse_cli_args() -> Tuple[argparse.Namespace, list[str]]:
+def parse_cli_args() -> tuple[argparse.Namespace, list[str]]:
```

As per coding guidelines: "Use built-in generics (list, dict, tuple) instead of typing equivalents".
src/megatron/bridge/training/utils/packed_seq_utils.py (2)

77-78: Missing type hint for the `use_fp8_padding` parameter.

Fix

```diff
-def preprocess_packed_seqs(
-    input_ids: torch.Tensor, attention_mask: torch.Tensor, pre_process: bool = True, use_fp8_padding=False
-) -> tuple[torch.Tensor, PackedSeqParams]:
+def preprocess_packed_seqs(
+    input_ids: torch.Tensor, attention_mask: torch.Tensor | None, pre_process: bool = True, use_fp8_padding: bool = False
+) -> tuple[torch.Tensor, PackedSeqParams]:
```

As per coding guidelines: "Use type hints for function arguments and return types" and "Use 'T | None' for nullable types instead of 'Optional[T]'".

196-197: Minor: use unpacking instead of list concatenation.

Per Ruff RUF005, prefer `[batch_size, seq_len, *output.shape[2:]]`.

Fix

```diff
-    shape = [batch_size, seq_len] + list(output.shape[2:])  # 1,packed, dim -> batch_size, seq_len, dim
+    shape = [batch_size, seq_len, *output.shape[2:]]  # 1,packed, dim -> batch_size, seq_len, dim
```

src/megatron/bridge/models/qwen_vl/qwen3_vl_provider.py (2)
39-39: Remove unused import `get_vision_model_config`.

Flake8 confirms `get_vision_model_config` is imported but never used in this file.

Fix

```diff
-from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.transformer_config import get_vision_model_config
```

139-157: Vision spec construction is duplicated between dense and MoE providers.

Lines 139-144 (dense provider) and lines 307-312 (MoE provider) construct identical `vision_transformer_layer_spec` and `vision_patch_merger_spec` objects. Consider extracting a shared helper method (e.g., on a common base class or as a module-level function) to keep these in sync.

src/megatron/bridge/training/initialize.py (1)
589-648: Dist-train PG split: cache rank-membership checks and validate world-size partitioning.

This block calls `is_rank_in_pg(...)` multiple times for the same collections; please cache the results to avoid repeated `get_process_group_ranks()` calls during init.

Also recommend adding a sanity check that `model_config.vision_world_size + model_config.language_world_size == torch.distributed.get_world_size()` (or `get_world_size_safe()`); otherwise you can end up with ranks that are in neither the vision nor the language group yet pass the current logic until a later failure.
src/megatron/bridge/training/config.py (1)
1318-1413: Add invariant checks foruse_dist_trainso dp-size computation can’t silently go wrong.When
use_dist_train=True, dp-size is derived frommodel.language_world_size. That’s fine ifmodel.tensor_model_parallel_size,pipeline_model_parallel_size, andcontext_parallel_sizeare also the language-side values at this point. If those fields still reflect a “combined” or vision-side topology, dp-size will be incorrect.Suggestion: add a small validation in
set_data_parallel_size()(ormodel.finalize()) thatlanguage_world_size % (tp*pp*cp) == 0with the exact values used here, and raise a clear error if not.As per coding guidelines, "Be explicit about required vs optional fields in configuration objects; do not add arbitrary defaults".
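A sketch of that validation under the stated assumptions (`set_data_parallel_size` and the field names are taken from the review; the body is illustrative, not the actual implementation):

```python
def set_data_parallel_size(self) -> None:
    if getattr(self.model, "use_dist_train", False):
        tp = self.model.tensor_model_parallel_size
        pp = self.model.pipeline_model_parallel_size
        cp = self.model.context_parallel_size
        lang_ws = self.model.language_world_size
        # dp-size is derived from the language-side world size, so tp/pp/cp
        # must be the language-side values and must divide it evenly.
        if lang_ws % (tp * pp * cp) != 0:
            raise ValueError(
                f"language_world_size ({lang_ws}) is not divisible by "
                f"tp*pp*cp ({tp}*{pp}*{cp}={tp * pp * cp}); dp-size would be wrong"
            )
        self.data_parallel_size = lang_ws // (tp * pp * cp)
```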
src/megatron/bridge/models/gpt_provider.py (1)
200-210: Use `T | None` for the new nullable config fields.

Align the new fields with the repo's nullable type-hint convention. As per coding guidelines, use 'T | None' for nullable types instead of 'Optional[T]'.

🛠️ Proposed fix

```diff
-    vision_model_type: Optional[str] = None
+    vision_model_type: str | None = None
@@
-    dist_train_vision_chunk_size: Optional[int] = 1
-    vision_world_size: Optional[int] = None
-    language_world_size: Optional[int] = None
-    vision_tensor_model_parallel_size: Optional[int] = None
-    vision_pipeline_model_parallel_size: Optional[int] = None
-    vision_context_parallel_size: Optional[int] = None
-    vision_expert_tensor_parallel_size: Optional[int] = None
-    vision_expert_model_parallel_size: Optional[int] = None
+    dist_train_vision_chunk_size: int | None = 1
+    vision_world_size: int | None = None
+    language_world_size: int | None = None
+    vision_tensor_model_parallel_size: int | None = None
+    vision_pipeline_model_parallel_size: int | None = None
+    vision_context_parallel_size: int | None = None
+    vision_expert_tensor_parallel_size: int | None = None
+    vision_expert_model_parallel_size: int | None = None
```

src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/utils.py (4)
16-17: Use modern union type syntax per coding guidelines.

The coding guidelines require `T | None` instead of `Optional[T]` and `X | Y` instead of `Union[X, Y]` (Python 3.10+). These typing imports are used throughout the file (lines 92–94, 103, etc.).

Suggested fix

```diff
-from typing import Optional, Union
+from __future__ import annotations
```

Then replace all `Optional[X]` with `X | None` and `Union[X, Y]` with `X | Y` throughout the file.
569-596: `AllGatherVisionEmbeddings.backward` — integer `start_idx` when `cp_rank == 0` may cause slicing issues.

On Line 593, when `cp_rank == 0`, `torch.cat(seqlens_on_cp_ranks[:0])` produces an empty tensor, and `.sum()` returns a tensor (scalar `0`), not a Python `int`. Then on Line 595, `grad_output[start_idx:end_idx]` uses tensor indices. While this works in practice, mixing tensor and int index types is fragile. More importantly, `ctx.save_for_backward` stores all `seqlens_on_cp_ranks` tensors. If these are large lists, this is fine, but verify the tensors are detached and on the correct device.

Safer start_idx computation

```diff
-    start_idx = torch.cat(seqlens_on_cp_ranks[:cp_rank]).sum() if cp_rank != 0 else 0
+    start_idx = int(torch.cat(seqlens_on_cp_ranks[:cp_rank]).sum().item()) if cp_rank != 0 else 0
```
700-701: `postprocess_packed_seqs`: redundant `.cpu()` call may trigger an extra D2H sync.

`attention_mask.sum(dim=1, dtype=torch.int32).cpu().tolist()` on Line 701 explicitly moves to CPU. If `attention_mask` is already on GPU, `.tolist()` alone will perform the D2H transfer. The extra `.cpu()` call is harmless but redundant. More importantly, this is a synchronization point — consider batching it with other D2H transfers above if performance matters here.
290-292: Magic default token IDs should reference the config, not be hardcoded.

`image_token_id: int = 151655` and `video_token_id: int = 151656` are hardcoded defaults in `reorganize_inputs`. These IDs are model-specific and already defined in `Qwen3VLTransformerConfig`. Hardcoding them here risks a silent mismatch if the config changes.

src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/vision_model.py (2)
134-195: `fast_pos_embed_interpolate`: potential performance concern with `.tolist()` in the inner loop.

Lines 170–171 call `.tolist()` and `.extend()` inside the per-image loop, converting GPU tensors to Python lists element by element. For a large number of images/grids, this creates many small D2H synchronizations and Python list operations. Consider accumulating the index and weight tensors on GPU and converting once after the loop; a sketch follows.

Also, the unused loop variable `t` on Line 140 (flagged by Ruff B007) can be replaced with `_` to signal intent.
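A hedged sketch of the accumulate-then-convert pattern. The helper functions and grid values are stand-ins, not the actual variables in `fast_pos_embed_interpolate`:

```python
import torch


def compute_indices(grid: torch.Tensor) -> torch.Tensor:
    # Stand-in for the real per-grid index computation.
    return torch.arange(int(grid.prod()), device=grid.device)


def compute_weights(grid: torch.Tensor) -> torch.Tensor:
    # Stand-in for the real per-grid weight computation.
    return torch.ones(int(grid.prod()), device=grid.device)


grids = [torch.tensor([1, 4, 4]), torch.tensor([1, 2, 6])]  # toy (t, h, w) grids

# Accumulate tensors per image instead of extending Python lists with .tolist():
idx_chunks = [compute_indices(g) for g in grids]
weight_chunks = [compute_weights(g) for g in grids]

# Concatenate on-device and cross the device boundary once, if a list is really needed.
all_indices = torch.cat(idx_chunks)
all_weights = torch.cat(weight_chunks)
```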
239-244: `torch.split` followed by `torch.cat` is a no-op.

Lines 242–244 split `hidden_states` into chunks and immediately concatenate them back. Unless there's a side effect or a planned insertion between split and cat, this can be removed.

```diff
-        split_sizes = (grid_thw.prod(-1) // self.spatial_merge_size**2).tolist()
-        hidden_states = torch.split(hidden_states, split_sizes)
-        hidden_states = torch.cat(hidden_states, dim=0)
```

If this is a placeholder for future logic, add a comment explaining it.
src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/transformer_block.py (2)
64-64: Remove or document commented-out code.

Lines 64 and 76 contain commented-out `model_comm_pgs` parameters without any explanation. As per coding guidelines, commented-out code should include a comment describing its purpose and why it's commented out, or be removed before merging.

```diff
-    # model_comm_pgs: ModelCommProcessGroups = None,
...
-    # model_comm_pgs=model_comm_pgs,
```

Also applies to: 76-76
336-339: O(n) index lookup `self.deepstack_visual_indexes.index(l_no)` on every layer forward pass.

Both in the checkpointed path (Line 134) and the non-checkpointed path (Line 337), `.index(l_no)` performs a linear scan of `deepstack_visual_indexes` for every matching layer. Consider converting to a `dict` mapping `layer_no → deepstack_idx` in `__init__` for O(1) lookup.

Proposed optimization in __init__

```diff
     self.deepstack_visual_indexes = config.deepstack_visual_indexes
+    self._deepstack_index_map = {
+        l_no: idx for idx, l_no in enumerate(config.deepstack_visual_indexes)
+    }
```

Then use `self._deepstack_index_map[l_no]` instead of `self.deepstack_visual_indexes.index(l_no)`.
```python
from megatron.bridge.recipes.qwen_vl.qwen3_vl import qwen3_vl_30b_a3b_pretrain_config
from megatron.bridge.training.pretrain import pretrain
from megatron.bridge.training.vlm_step import forward_step
```
🧩 Analysis chain
🏁 Script executed:

```bash
#!/bin/bash
# Check if vlm_step.forward_step handles Qwen3VL models (look for visual_inputs injection)
echo "=== Checking vlm_step.py for visual_inputs handling ==="
rg -n 'visual_inputs' src/megatron/bridge/training/vlm_step.py
echo ""
echo "=== Checking qwen3vl_step.py for Qwen3VL-specific logic ==="
rg -n 'Qwen3VL|qwen3vl|qwen3_vl' src/megatron/bridge/training/qwen3vl_step.py | head -30
```
🏁 Script executed:

```bash
#!/bin/bash
# Check if qwen3vl_step.py exists
echo "=== Looking for qwen3vl_step.py ==="
find . -name "*qwen3vl_step*" -o -name "*qwen3_vl_step*"
echo ""
echo "=== Check what files exist in training module ==="
ls -la src/megatron/bridge/training/ | grep -E "step\.py|qwen"
echo ""
echo "=== If qwen3vl_step exists, show its full content ==="
if [ -f "src/megatron/bridge/training/qwen3vl_step.py" ]; then
    wc -l src/megatron/bridge/training/qwen3vl_step.py
    cat -n src/megatron/bridge/training/qwen3vl_step.py
fi
```
Import `forward_step` from `qwen3vl_step` instead of `vlm_step`.

This example imports `forward_step` from `vlm_step`, which is a generic VLM module. However, since this is a Qwen3-VL specific pretraining script, it should use `from megatron.bridge.training.qwen3vl_step import forward_step` instead.

`qwen3vl_step` contains Qwen3-VL–specific optimizations that `vlm_step` lacks:
- Assertion that model is Qwen3-VL (line 266-267)
- Data format handling tailored for Qwen3-VL (bshd vs thd)
- Position_ids calculation in model forward (not in forward_step)
- Packed sequence parameter handling specific to Qwen3-VL
- Proper multi-modal input injection for pixel_values and image_grid_thw
Using the generic `vlm_step.forward_step` would miss these optimizations and potentially cause training to diverge from the intended Qwen3-VL behavior.
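A one-line sketch of the suggested change for this script (the module path is the one added by this PR, per the review):

```python
from megatron.bridge.training.qwen3vl_step import forward_step
```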
🤖 Prompt for AI Agents
In `@examples/recipes/decentralized_pg/pretrain_qwen3_vl_simple.py` at line 37,
Replace the import of forward_step from the generic vlm_step with the Qwen3-VL
specific implementation: change the import to use
megatron.bridge.training.qwen3vl_step so the script uses forward_step
implemented in qwen3vl_step; this ensures the Qwen3-VL assertions, data-format
handling (bshd vs thd), position_ids logic, packed sequence handling, and
multimodal input injection (pixel_values and image_grid_thw) are applied instead
of the generic vlm_step behavior.
```python
from megatron.bridge.training.qwen3vl_step import forward_step
from megatron.bridge.utils.common_utils import get_rank_safe

from functools import partial
```
🧩 Analysis chain
🏁 Script executed:

```bash
# First, verify if qwen3vl_step.py has an assertion for Qwen3-VL models
rg -n 'assert.*[Qq]wen' --type=py -C3 src/megatron/bridge/training/qwen3vl_step.py
```

🏁 Script executed:

```bash
# Check if partial is actually used anywhere in finetune_qwen_vl.py
rg -n 'partial' examples/recipes/qwen_vl/finetune_qwen_vl.py
```

🏁 Script executed:

```bash
# View the imports section of finetune_qwen_vl.py to verify ordering
head -120 examples/recipes/qwen_vl/finetune_qwen_vl.py | tail -40
```

🏁 Script executed:

```bash
# Check how the script determines recipe_name and model_family
rg -n 'recipe_name|model_family' examples/recipes/qwen_vl/finetune_qwen_vl.py -A2 -B2
```

🏁 Script executed:

```bash
# Verify if vlm_step.py exists and has a forward_step compatible with Qwen2.5-VL
find . -name "vlm_step.py" -type f
```

🏁 Script executed:

```bash
# Check what forward_step is in vlm_step.py if it exists
fd -e py vlm_step | head -5 | xargs cat
```

🏁 Script executed:

```bash
# Check if there are any other uses of qwen25_vl_recipes in the file to understand full context
rg -n 'qwen25_vl_recipes|Qwen2.5' examples/recipes/qwen_vl/finetune_qwen_vl.py
```
Qwen2.5-VL recipes will fail: `forward_step` is unconditionally imported from `qwen3vl_step`, which asserts the model is Qwen3-VL.

The script supports both Qwen2.5-VL and Qwen3-VL (line 19, recipe selection lines 204–211), but `forward_step` is imported from `qwen3vl_step` before `main()` runs. The `qwen3vl_step.forward_step` asserts the model is Qwen3-VL (line 267), causing Qwen2.5-VL recipes to fail.

Move the import into `main()` after the recipe/model_family is determined and select the appropriate module:

```python
if model_family == "Qwen3-VL":
    from megatron.bridge.training.qwen3vl_step import forward_step
else:
    from megatron.bridge.training.vlm_step import forward_step
```

Additionally, remove `from functools import partial` on line 108—it is unused and is a stdlib import incorrectly placed after first-party imports, violating import ordering guidelines.
🧰 Tools
🪛 Flake8 (7.3.0)
[error] 108-108: 'functools.partial' imported but unused
(F401)
🤖 Prompt for AI Agents
In `@examples/recipes/qwen_vl/finetune_qwen_vl.py` around lines 106 - 108, The
file unconditionally imports forward_step from
megatron.bridge.training.qwen3vl_step which asserts the model is Qwen3-VL and
breaks Qwen2.5-VL runs; move that import into main() after you determine
recipe/model_family and conditionally import forward_step from the correct
module (if model_family == "Qwen3-VL" import from
megatron.bridge.training.qwen3vl_step else import from
megatron.bridge.training.vlm_step) so the assertion isn’t triggered at import
time, and remove the unused from functools import partial import which is
misplaced after first-party imports.
```python
from megatron.core.transformer.attention import *

from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.rope import apply_rotary_pos_emb_absolute
```
Replace the star import with explicit dependencies (ruff F403/F405).
The wildcard import is causing undefined-name lint errors. Import the exact symbols used in this file to keep the namespace explicit and pass linting.
🛠️ Proposed fix (explicit imports)
```diff
-from megatron.core.transformer.attention import *
+from typing import Optional, Tuple, Union
+
+from torch import Tensor
+from megatron.core.inference.contexts import BaseInferenceContext
+from megatron.core.packed_seq_params import PackedSeqParams
+from megatron.core.transformer.attention import (
+    HAVE_FA3,
+    SelfAttention,
+    is_fa_min_version,
+    rearrange,
+)
+from megatron.core.utils import deprecate_inference_params, nvtx_range_pop, nvtx_range_push
```

🧰 Tools
🪛 Ruff (0.15.0)
[error] 16-16: from megatron.core.transformer.attention import * used; unable to detect undefined names
(F403)
```python
# process groups
self.pg_collection = pg_collection
self.cp_group = pg_collection.cp
self.tp_group = pg_collection.tp
self.pp_group = pg_collection.pp
assert hasattr(self.pg_collection, "embd"), (
```
Guard against `pg_collection=None` before dereferencing.

`pg_collection` is optional in the signature but is immediately accessed (`pg_collection.cp`). If callers omit it, this will raise.
🛠️ Proposed fix
```diff
-        # process groups
-        self.pg_collection = pg_collection
+        if pg_collection is None:
+            pg_collection = get_pg_collection(self)
+        # process groups
+        self.pg_collection = pg_collection
         self.cp_group = pg_collection.cp
         self.tp_group = pg_collection.tp
         self.pp_group = pg_collection.pp
```

🤖 Prompt for AI Agents
In `@src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py` around lines
109 - 114, Guard against pg_collection being None before dereferencing: check if
pg_collection is not None before accessing pg_collection.cp/tp/pp and assigning
self.pg_collection, self.cp_group, self.tp_group, self.pp_group; if
pg_collection is None set those group attributes to None (or appropriate
defaults) and adjust the assert (or replace it with an explicit check) so you
only call hasattr(self.pg_collection, "embd") when self.pg_collection is not
None. Ensure you update any code paths that assume these groups exist to handle
the None/default case.
| # Slightly modified from Qwen3VLModel.get_rope_index | ||
| def get_rope_index( | ||
| spatial_merge_size: int, | ||
| image_token_id: int, | ||
| video_token_id: int, | ||
| vision_start_token_id: int, | ||
| input_ids: Optional[torch.LongTensor] = None, | ||
| image_grid_thw: Optional[torch.LongTensor] = None, | ||
| video_grid_thw: Optional[torch.LongTensor] = None, | ||
| attention_mask: Optional[torch.Tensor] = None, | ||
| packed_seq_params: Optional[PackedSeqParams] = None, | ||
| ) -> tuple[torch.Tensor, torch.Tensor]: | ||
| """Different from the original implementation, Qwen3VL use timestamps rather than absolute time position ids.""" | ||
|
|
||
| def forward( | ||
| self, | ||
| position_ids: torch.Tensor, | ||
| mrope_section: List[int] = None, | ||
| packed_seq_params: Optional[PackedSeqParams] = None, | ||
| **kwargs, | ||
| ) -> Tensor: | ||
| """Forward pass for non-MoE Qwen3-VL RoPE. | ||
| # Since we use timestamps to seperate videos, like <t1> <vision_start> <frame1> <vision_end> <t2> <vision_start> <frame2> <vision_end>, the video_grid_thw should also be split | ||
| if video_grid_thw is not None: | ||
| video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0) | ||
| video_grid_thw[:, 0] = 1 | ||
|
|
||
| Args: | ||
| position_ids: Position IDs tensor | ||
| mrope_section: Optional mrope section (if not provided, uses self.mrope_section) | ||
| """ | ||
| if mrope_section is None: | ||
| mrope_section = self.mrope_section | ||
|
|
||
| if position_ids.ndim == 2: | ||
| position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) | ||
| inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) | ||
| position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) | ||
| freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) | ||
| freqs = self.apply_interleaved_mrope(freqs, mrope_section) | ||
| emb = torch.cat((freqs, freqs), dim=-1) | ||
| emb = emb[..., None, :].transpose(0, 1).contiguous() | ||
| _ = packed_seq_params # packed sequences not supported yet | ||
| return emb | ||
| if packed_seq_params is not None and attention_mask is None and input_ids is not None: | ||
| # Build an attention mask from packed sequence metadata when one is not provided. | ||
| # cu_seqlens_q entries are cumulative lengths; their diffs give per-sample lengths. | ||
| cu_seqlens = packed_seq_params.cu_seqlens_q | ||
| if cu_seqlens is not None and cu_seqlens.numel() >= 2: | ||
| seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] | ||
| attention_mask = torch.zeros_like(input_ids, dtype=input_ids.dtype) | ||
| max_len = attention_mask.shape[1] | ||
| for i, seq_len in enumerate(seq_lens.tolist()): | ||
| valid = min(int(seq_len), max_len) | ||
| attention_mask[i, :valid] = 1 | ||
| else: | ||
| # Fallback to a dense mask if packed metadata is missing. | ||
| attention_mask = torch.ones_like(input_ids) | ||
|
|
||
| mrope_position_deltas = [] | ||
| if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): | ||
| total_input_ids = input_ids | ||
| if attention_mask is None: | ||
| attention_mask = torch.ones_like(total_input_ids) | ||
| position_ids = torch.ones( | ||
| 3, | ||
| input_ids.shape[0], | ||
| input_ids.shape[1], | ||
| dtype=input_ids.dtype, | ||
| device=input_ids.device, | ||
| ) | ||
        image_index, video_index = 0, 0
        attention_mask = attention_mask.to(total_input_ids.device)
        for i, input_ids in enumerate(total_input_ids):
            input_ids = input_ids[attention_mask[i] == 1]
            image_nums, video_nums = 0, 0
            vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
            vision_tokens = input_ids[vision_start_indices + 1]
            image_nums = (vision_tokens == image_token_id).sum()
            video_nums = (vision_tokens == video_token_id).sum()
            input_tokens = input_ids.tolist()
            llm_pos_ids_list: list = []
            st = 0
            remain_images, remain_videos = image_nums, video_nums
            for _ in range(image_nums + video_nums):
                if image_token_id in input_tokens and remain_images > 0:
                    ed_image = input_tokens.index(image_token_id, st)
                else:
                    ed_image = len(input_tokens) + 1
                if video_token_id in input_tokens and remain_videos > 0:
                    ed_video = input_tokens.index(video_token_id, st)
                else:
                    ed_video = len(input_tokens) + 1
                if ed_image < ed_video:
                    t, h, w = (
                        image_grid_thw[image_index][0],
                        image_grid_thw[image_index][1],
                        image_grid_thw[image_index][2],
                    )
                    image_index += 1
                    remain_images -= 1
                    ed = ed_image
                else:
                    t, h, w = (
                        video_grid_thw[video_index][0],
                        video_grid_thw[video_index][1],
                        video_grid_thw[video_index][2],
                    )
                    video_index += 1
                    remain_videos -= 1
                    ed = ed_video
                llm_grid_t, llm_grid_h, llm_grid_w = (
                    t.item(),
                    h.item() // spatial_merge_size,
                    w.item() // spatial_merge_size,
                )
                text_len = ed - st

                st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                # t_index is always 0 because llm_grid_t is always 1 (we use timestamps to encode the temporal information for videos)
                t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
                h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
                w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
                llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
                st = ed + llm_grid_t * llm_grid_h * llm_grid_w

            if st < len(input_tokens):
                st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                text_len = len(input_tokens) - st
                llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

            llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
            position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
            mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
        mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
        return position_ids, mrope_position_deltas
    else:
        if attention_mask is not None:
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
            max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
            mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
        else:
            position_ids = (
                torch.arange(input_ids.shape[1], device=input_ids.device)
                .view(1, 1, -1)
                .expand(3, input_ids.shape[0], -1)
            )
            mrope_position_deltas = torch.zeros(
                [input_ids.shape[0], 1],
                device=input_ids.device,
                dtype=input_ids.dtype,
            )

        return position_ids, mrope_position_deltas
get_rope_index has multiple likely runtime errors (tensor→int, shadowing, None-handling).
Key blockers:
- image_nums/video_nums are tensors (.sum()), then used in range(image_nums + video_nums) → TypeError.
- The loop for i, input_ids in enumerate(total_input_ids): shadows the input_ids argument, making the function harder to reason about and risking bugs.
- In the "no image/video grids" branch, input_ids can be None while attention_mask is also None → input_ids.shape[...] will crash.
- vision_start_indices + 1 can go OOB if a <vision_start> token appears at the last position.
Targeted fix sketch (minimal)
def get_rope_index(
@@
- if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
+ if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
total_input_ids = input_ids
@@
- for i, input_ids in enumerate(total_input_ids):
- input_ids = input_ids[attention_mask[i] == 1]
+ for i, sample_input_ids in enumerate(total_input_ids):
+ sample_input_ids = sample_input_ids[attention_mask[i] == 1]
@@
- vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
- vision_tokens = input_ids[vision_start_indices + 1]
- image_nums = (vision_tokens == image_token_id).sum()
- video_nums = (vision_tokens == video_token_id).sum()
- input_tokens = input_ids.tolist()
+ vision_start_indices = torch.argwhere(sample_input_ids == vision_start_token_id).squeeze(1)
+ vision_start_indices = vision_start_indices[vision_start_indices + 1 < sample_input_ids.numel()]
+ vision_tokens = sample_input_ids[vision_start_indices + 1]
+ image_nums = int((vision_tokens == image_token_id).sum().item())
+ video_nums = int((vision_tokens == video_token_id).sum().item())
+ input_tokens = sample_input_ids.tolist()
@@
- remain_images, remain_videos = image_nums, video_nums
+ remain_images, remain_videos = image_nums, video_nums
for _ in range(image_nums + video_nums):
@@
- mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
- mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
+ mrope_position_deltas.append(llm_positions.max() + 1 - total_input_ids.shape[1])
+ mrope_position_deltas = torch.stack(mrope_position_deltas).unsqueeze(1)
@@
- else:
+ else:
+ if attention_mask is None and input_ids is None:
+ raise ValueError("input_ids must be provided when attention_mask is None")| model = Qwen3VLModel( | ||
| language_transformer_config=language_transformer_config, | ||
| language_transformer_layer_spec=language_transformer_layer_spec, | ||
| vision_transformer_config=hf_vision_config, | ||
| vision_transformer_layer_spec=vision_transformer_layer_spec, | ||
| vision_patch_merger_spec=vision_patch_merger_spec, | ||
| pre_process=pre_process, | ||
| post_process=post_process, | ||
| pg_collection=self._pg_collection, | ||
| add_encoder=self.add_encoder, | ||
| add_decoder=self.add_decoder, | ||
| ) |
🧩 Analysis chain
🏁 Scripts executed against NVIDIA-NeMo/Megatron-Bridge (outputs omitted for brevity):

# First, let's check the structure of the file under review and its parent classes
cat -n src/megatron/bridge/models/qwen_vl/qwen3_vl_provider.py | head -50

# Search for _pg_collection, add_encoder, add_decoder in qwen3 directory
rg -n '_pg_collection|add_encoder|add_decoder' --type=py -C2 src/megatron/bridge/models/qwen3/

# Check if these attributes are defined in the qwen3_vl_provider file itself
rg -n 'class.*Provider|def __init__|self._pg_collection|self.add_encoder|self.add_decoder' --type=py src/megatron/bridge/models/qwen_vl/qwen3_vl_provider.py

# Find where Qwen3ModelProvider and Qwen3MoEModelProvider are defined
find src/megatron/bridge/models -name "*.py" -type f | head -20

# Search for Qwen3ModelProvider definition
rg -n 'class Qwen3ModelProvider|class Qwen3MoEModelProvider' --type=py src/megatron/bridge/

# Check the imports in qwen3_vl_provider.py to understand where these come from
head -50 src/megatron/bridge/models/qwen_vl/qwen3_vl_provider.py

# Check the parent classes Qwen3ModelProvider and Qwen3MoEModelProvider
sed -n '257,283p' src/megatron/bridge/models/qwen/qwen_provider.py

# Check Qwen3MoEModelProvider
sed -n '364,400p' src/megatron/bridge/models/qwen/qwen_provider.py

# Check what GPTModelProvider (the grandparent) defines - look for __init__ and these attributes
rg -n 'class GPTModelProvider|def __init__|self._pg_collection|self.add_encoder|self.add_decoder' --type=py src/megatron/bridge/models/ -A5 | grep -E '(class GPTModelProvider|__init__|_pg_collection|add_encoder|add_decoder)' | head -30

# Find GPTModelProvider definition and its __init__
rg -n 'class GPTModelProvider' --type=py src/megatron/bridge/models/

# Get the full definition of GPTModelProvider including __init__
sed -n '1,100p' src/megatron/bridge/models/gpt_provider.py

# Search for _pg_collection and add_encoder/add_decoder in GPTModelProvider
rg -n '_pg_collection|add_encoder|add_decoder' --type=py -B2 -A2 src/megatron/bridge/models/gpt_provider.py | head -60

# Search for add_encoder and add_decoder definitions in provider classes
rg -n 'add_encoder|add_decoder' --type=py src/megatron/bridge/models/gpt_provider.py

# Check ModelProviderMixin and TransformerConfig for these attributes
rg -n 'add_encoder|add_decoder' --type=py src/megatron/bridge/models/model_provider.py

# Check the full Qwen3VLModelProvider class definition to see if it defines these
sed -n '47,200p' src/megatron/bridge/models/qwen_vl/qwen3_vl_provider.py

# Search for add_encoder and add_decoder class attributes or properties in the entire bridge models
rg -n 'add_encoder\s*[:=]|add_decoder\s*[:=]' --type=py src/megatron/bridge/models/

# Check if they might be properties or methods
rg -n '@property' --type=py -A3 src/megatron/bridge/models/gpt_provider.py | grep -E '(property|add_encoder|add_decoder)'

# Search in transformer_config.py or model_provider.py
rg -n 'add_encoder|add_decoder' --type=py src/megatron/bridge/models/transformer_config.py src/megatron/bridge/models/model_provider.py

# Check if add_encoder and add_decoder are defined as properties or anywhere in parent classes
rg -n 'class TransformerConfig|class ModelProviderMixin' --type=py src/megatron/bridge/models/ -A20 | grep -E '(class|add_encoder|add_decoder|@property)'

# Look at the parent class definitions more closely
sed -n '1,300p' src/megatron/bridge/models/model_provider.py | head -100

# Compare with nemotron_vl_provider to see how it handles these parameters
sed -n '120,140p' src/megatron/bridge/models/nemotron_vl/nemotron_vl_provider.py

# Check if add_encoder and add_decoder are defined as dataclass fields in Qwen3VLModelProvider or parents
sed -n '47,200p' src/megatron/bridge/models/qwen_vl/qwen3_vl_provider.py | grep -n 'add_encoder\|add_decoder'

# Check the full Qwen3ModelProvider definition for these fields
rg -n '@dataclass' --type=py -A100 src/megatron/bridge/models/qwen/qwen_provider.py | grep -E '(dataclass|add_encoder|add_decoder)'

# Check if these are defined in TransformerConfig or GPTModelProvider
sed -n '100,250p' src/megatron/bridge/models/gpt_provider.py | grep -n 'add_encoder\|add_decoder'

# Look at the full GPTModelProvider class, especially around the provide method
sed -n '127,350p' src/megatron/bridge/models/gpt_provider.py | grep -n 'add_encoder\|add_decoder\|class\|def provide'

# Check what dataclass fields are defined in GPTModelProvider itself
sed -n '127,280p' src/megatron/bridge/models/gpt_provider.py

# Check the complete provide method in GPTModelProvider to see how add_encoder/add_decoder are used
sed -n '215,380p' src/megatron/bridge/models/gpt_provider.py

# Get a comprehensive view - search for where add_encoder/add_decoder might be defined as fields
rg -n '@dataclass|add_encoder|add_decoder' --type=py src/megatron/bridge/models/gpt_provider.py | head -40

# Final confirmation: search comprehensively for where add_encoder and add_decoder are defined as dataclass fields
rg -n 'add_encoder\s*:|add_decoder\s*:' --type=py src/megatron/bridge/models/

# Check the complete Qwen3VLModelProvider class definition to confirm these fields are not there
sed -n '47,300p' src/megatron/bridge/models/qwen_vl/qwen3_vl_provider.py | grep -n 'add_encoder\|add_decoder'

# Check if there's an __init__ method in Qwen3VLModelProvider that initializes them
sed -n '47,300p' src/megatron/bridge/models/qwen_vl/qwen3_vl_provider.py | grep -n '__init__'

# Let's verify one more time if Qwen3VLModelProvider or its parents define add_encoder/add_decoder as fields
grep -n '@dataclass\|class Qwen3VLModelProvider\|add_encoder\|add_decoder' src/megatron/bridge/models/qwen_vl/qwen3_vl_provider.py | head -20

# Check complete parent class hierarchy for these fields
rg -n '@dataclass' --type=py -A150 src/megatron/bridge/models/qwen/qwen_provider.py | grep -E '(class Qwen3ModelProvider|add_encoder|add_decoder)' | head -10

# One final check - see if these might be properties defined somewhere
rg -n '@property' --type=py -A3 src/megatron/bridge/models/gpt_provider.py | grep -B1 -A2 'add_encoder\|add_decoder'

# Check if add_encoder and add_decoder might be inherited from TransformerConfig
sed -n '50,150p' src/megatron/bridge/models/transformer_config.py | grep -n 'add_encoder\|add_decoder\|@dataclass'

# One more comprehensive check - look at all occurrences in the provider hierarchy
rg -n 'class.*Provider|add_encoder|add_decoder' --type=py src/megatron/bridge/models/gpt_provider.py src/megatron/bridge/models/transformer_config.py | head -30

# Look at the nemotron example to see how it handles this (hardcoded True)
sed -n '100,145p' src/megatron/bridge/models/nemotron_vl/nemotron_vl_provider.py

# Let's view the complete Qwen3VLModelProvider class definition to ensure we're not missing anything
sed -n '46,185p' src/megatron/bridge/models/qwen_vl/qwen3_vl_provider.py
add_encoder and add_decoder are not defined in the provider class hierarchy.
These attributes are referenced at lines 155-156 (and lines 324-325 in the MoE variant) but are never initialized as dataclass fields in Qwen3VLModelProvider, Qwen3ModelProvider, or their parents. The _pg_collection attribute is properly inherited from GPTModelProvider, but add_encoder and add_decoder must be defined. Either add them as dataclass fields with appropriate defaults or pass hardcoded values like nemotron_vl_provider.py does.
🤖 Prompt for AI Agents
In `@src/megatron/bridge/models/qwen_vl/qwen3_vl_provider.py` around lines 146 -
157, Qwen3VLModelProvider (and Qwen3ModelProvider) reference attributes
add_encoder and add_decoder when constructing Qwen3VLModel but those attributes
are never defined; fix by adding dataclass fields add_encoder: bool = False and
add_decoder: bool = False to the provider class (or its parent) so they're
initialized, or alternatively pass explicit boolean literals into the
Qwen3VLModel constructor where add_encoder/add_decoder are currently used
(mimicking nemotron_vl_provider.py); update both the Qwen3VLModelProvider and
the MoE variant locations that reference add_encoder/add_decoder to use the new
fields or hardcoded values.
Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!
| if "visual_inputs" in batch: | ||
| # convert visual_inputs to multi_modal_inputs which is a dict contains "pixel_values" and "image_grid_thw" | ||
| # TODO(jinliangl): add video support | ||
| multi_modal_inputs = batch.get("visual_inputs").normalized_for_model() | ||
| else: |
Guard against visual_inputs=None before calling normalized_for_model().
If the dataset provides visual_inputs=None (e.g., text-only samples), the current code will raise.
🛠️ Proposed fix
- if "visual_inputs" in batch:
- # convert visual_inputs to multi_modal_inputs which is a dict contains "pixel_values" and "image_grid_thw"
- # TODO(jinliangl): add video support
- multi_modal_inputs = batch.get("visual_inputs").normalized_for_model()
- else:
- multi_modal_inputs = {}
+ visual_inputs = batch.get("visual_inputs")
+ if visual_inputs is not None:
+ # convert visual_inputs to multi_modal_inputs which is a dict contains "pixel_values" and "image_grid_thw"
+ # TODO(jinliangl): add video support
+ multi_modal_inputs = visual_inputs.normalized_for_model()
+ else:
+ multi_modal_inputs = {}🤖 Prompt for AI Agents
In `@src/megatron/bridge/training/qwen3vl_step.py` around lines 136 - 140, The
code currently calls batch.get("visual_inputs").normalized_for_model() without
checking for None; update the block around where "visual_inputs" is handled so
you first assign visual = batch.get("visual_inputs") and only call
visual.normalized_for_model() if visual is not None, otherwise set
multi_modal_inputs to None (or skip creating it) so downstream uses of
multi_modal_inputs are guarded; modify the logic in qwen3vl_step.py where
multi_modal_inputs is created to reference the new visual variable and handle
the None case safely.
        if cfg.model.p2p_communicator.is_pp_last_stage:
            # Average loss across microbatches.
Guard last-stage loss reduction when p2p_communicator is unset.
cfg.model.p2p_communicator can be None when PP size is 1; accessing .is_pp_last_stage will raise. Fall back to is_pp_last_stage(pg_collection.pp).
🛠️ Proposed fix
- if cfg.model.p2p_communicator.is_pp_last_stage:
+ p2p = cfg.model.p2p_communicator
+ if (p2p is not None and p2p.is_pp_last_stage) or (p2p is None and is_pp_last_stage(pg_collection.pp)):🤖 Prompt for AI Agents
In `@src/megatron/bridge/training/train.py` around lines 739 - 740, Guard the
last-stage loss reduction by checking whether cfg.model.p2p_communicator is not
None before accessing .is_pp_last_stage; if it is None (PP size 1), call
is_pp_last_stage(pg_collection.pp) instead. Update the conditional that
currently reads cfg.model.p2p_communicator.is_pp_last_stage to first test
cfg.model.p2p_communicator and then fall back to
is_pp_last_stage(pg_collection.pp) so the loss averaging path only runs when the
correct determination of last stage succeeds.
# Copied from verl/verl/models/mcore/util.py
from megatron.core import mpu
🛠️ Refactor suggestion | 🟠 Major
Move import to the top of the file and drop the "Copied from" comment.
from megatron.core import mpu should be with the other imports (after line 19). A bare "Copied from" comment doesn't explain why the code was duplicated — per coding guidelines, comments should describe usage and rationale.
Additionally, both preprocess_packed_seqs and postprocess_packed_seqs are near-duplicates of the versions in src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/utils.py (with the addition of use_fp8_padding). Consider consolidating into a single shared location to avoid drift.
Proposed import fix
import torch
import math
from megatron.core.packed_seq_params import PackedSeqParams
+from megatron.core import mpu
...
-# Copied from verl/verl/models/mcore/util.py
-from megatron.core import mpu🤖 Prompt for AI Agents
In `@src/megatron/bridge/training/utils/packed_seq_utils.py` around lines 73 - 74,
Move the line "from megatron.core import mpu" up into the module import block
with the other imports (remove the "# Copied from ..." comment entirely), and
ensure the import is colocated with the top-level imports used by this file;
then refactor to avoid duplication by either importing the existing
implementations of preprocess_packed_seqs and postprocess_packed_seqs from the
shared implementation in the qwen_vl utils module or extract a new shared
utility module and have both locations import from it (preserve the
use_fp8_padding behavior when consolidating).
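One way the consolidation could look; the module path, signature, and docstring below are illustrative assumptions, not code from the PR:

# Hypothetical shared home, e.g. src/megatron/bridge/utils/packed_seq_common.py
from megatron.core import mpu
from megatron.core.packed_seq_params import PackedSeqParams


def preprocess_packed_seqs(input_ids, attention_mask, *, use_fp8_padding: bool = False):
    """Single shared implementation imported by both current call sites.

    Keyword-only use_fp8_padding defaults to False so the qwen_vl caller keeps
    its existing behavior unchanged; the training path opts in explicitly.
    """
    ...

postprocess_packed_seqs would move the same way, leaving the two existing modules as thin re-exports until callers are updated.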
    if pre_process:
        input_ids_rmpad = torch.zeros(shape, dtype=input_ids.dtype, device=input_ids.device)
        for i in range(batch_size):
            # Use Python int, so no GPU→CPU sync in the loop
            if cp_size <= 1:
                seqlen = seqlens_in_batch_cpu[i]
                start_idx = cu_seqlens_padded_cpu[i]
                input_ids_rmpad[start_idx : start_idx + seqlen] = input_ids[i, attention_mask[i]]
                continue

            seqlen_padded_i = seqlens_in_batch_padded_cpu[i]
            seqlen = seqlen_padded_i // cp_size
            half_seqlen = seqlen // 2
            start_idx = cu_seqlens_padded_cpu[i] // cp_size
            # split to 2 chunks
            d = input_ids[i, attention_mask[i]]
            input_ids_rmpad[start_idx : start_idx + half_seqlen] = d[
                half_seqlen * cp_rank : half_seqlen * (cp_rank + 1)
            ]

            remain_start = seqlen_padded_i - half_seqlen * (cp_rank + 1)
            remain_end = seqlen_padded_i - half_seqlen * cp_rank
            remain_end = min(remain_end, d.shape[0])
            remain_len = remain_end - remain_start
            if remain_len > 0:
                input_ids_rmpad[start_idx + half_seqlen : start_idx + half_seqlen + remain_len] = d[
                    remain_start:remain_end
                ]
Bug: attention_mask[i] crashes when attention_mask is None.
Lines 88-92 handle attention_mask is None for the seqlens_in_batch computation, but the loop body on lines 137 and 145 unconditionally indexes attention_mask[i] for boolean masking. When attention_mask is None and pre_process is True, this will raise TypeError: 'NoneType' object is not subscriptable.
Proposed fix
if pre_process:
input_ids_rmpad = torch.zeros(shape, dtype=input_ids.dtype, device=input_ids.device)
for i in range(batch_size):
# Use Python int, so no GPU→CPU sync in the loop
if cp_size <= 1:
seqlen = seqlens_in_batch_cpu[i]
start_idx = cu_seqlens_padded_cpu[i]
- input_ids_rmpad[start_idx : start_idx + seqlen] = input_ids[i, attention_mask[i]]
+ if attention_mask is not None:
+ input_ids_rmpad[start_idx : start_idx + seqlen] = input_ids[i, attention_mask[i]]
+ else:
+ input_ids_rmpad[start_idx : start_idx + seqlen] = input_ids[i, :seqlen]
continue
seqlen_padded_i = seqlens_in_batch_padded_cpu[i]
seqlen = seqlen_padded_i // cp_size
half_seqlen = seqlen // 2
start_idx = cu_seqlens_padded_cpu[i] // cp_size
# split to 2 chunks
- d = input_ids[i, attention_mask[i]]
+ d = input_ids[i, attention_mask[i]] if attention_mask is not None else input_ids[i, :seqlens_in_batch_cpu[i]]🤖 Prompt for AI Agents
In `@src/megatron/bridge/training/utils/packed_seq_utils.py` around lines 130 -
157, The loop unconditionally uses attention_mask[i] causing a crash when
attention_mask is None; update the loop in the pre_process branch to guard use
of attention_mask: compute d = input_ids[i, attention_mask[i]] if attention_mask
is not None, otherwise set d = input_ids[i, :seqlen] for the cp_size <= 1 branch
(use seqlen from seqlens_in_batch_cpu[i]) and d = input_ids[i, :seqlen_padded_i]
for the cp_size > 1 branch (use seqlen_padded_i from
seqlens_in_batch_padded_cpu[i]) before the subsequent slicing that writes into
input_ids_rmpad so indexing never tries to subscript None; adjust references
around input_ids_rmpad, seqlens_in_batch_cpu, seqlens_in_batch_padded_cpu,
cu_seqlens_padded_cpu, cp_size and cp_rank.
force-pushed from c86973a to adb12c4 (Compare)
force-pushed from adb12c4 to c86973a (Compare)
force-pushed from c86973a to 3382cce (Compare)
force-pushed from 3382cce to 2f55a42 (Compare)
Is there anyone who can help review this PR?
maybe @yaoyu-33 ?
    _pg_collection: Optional[ProcessGroupCollection] = None

    # vision model type will be used to override the vision model config.
It's a bit weird to add the DistTrain configs and vision configs here; find a better approach. This shouldn't be exposed to all gpt_provider models, only qwen3_vl; you can edit this in the qwen3 provider instead.
        This method calculates the data parallel size needed by setup methods, without
        triggering full validation or finalization of Megatron Core configs.
        """
        if hasattr(self.model, "use_dist_train") and self.model.use_dist_train:
Similar here; it's a bit of an over-hijacking of the workflow from DistTrain.
Maybe create a separate set_data_parallel_size method for DistTrain; that feels cleaner.
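A minimal sketch of that split, assuming the vision world size and parallel sizes live on self.model under the field names quoted elsewhere in this PR (treat all of them as assumptions):

    def _set_data_parallel_size_for_dist_train(self) -> None:
        """DistTrain-only sizing, so the mainstream path stays untouched."""
        world_size = get_world_size_safe()
        # Language ranks are whatever is left after the vision module's ranks.
        language_world_size = world_size - self.model.vision_world_size
        model_parallel = (
            self.model.tensor_model_parallel_size
            * self.model.pipeline_model_parallel_size
            * self.model.context_parallel_size
        )
        self.data_parallel_size = language_world_size // model_parallel

set_data_parallel_size would then dispatch to this method when DistTrain is enabled and keep its current body otherwise.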
    # Distributed - ensure data_parallel_size is calculated (might already be set by set_data_parallel_size)
    if not hasattr(self, "data_parallel_size") or self.data_parallel_size is None:
        world_size = get_world_size_safe()
        if hasattr(self.model, "use_dist_train") and self.model.use_dist_train:
Mainly, see if it's possible to avoid
"hasattr(self.model, "use_dist_train")"
since it adds overhead from a corner use case to the mainstream validate path. It's okay to put all the use_dist_train handling together in one block instead of adding conditions to the current mainstream path.
Ask Cursor to come up with some ideas here.
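One possible shape for that grouping; the class slice and helper names below are illustrative only:

class ConfigContainer:  # sketch of the relevant slice, not the real class body
    def _dist_train_enabled(self) -> bool:
        # The only place that knows about the optional attribute, so the
        # mainstream path below stays free of hasattr checks.
        return getattr(self.model, "use_dist_train", False)

    def validate(self) -> None:
        if self._dist_train_enabled():
            self._validate_dist_train()  # all DistTrain-specific checks in one block
            return
        # ... unchanged mainstream validation path ...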
        pg_collection_dict: A dictionary mapping module names to ProcessGroupCollections.
    """

    def finish_mpu_init() -> ProcessGroupCollection:
Does the return annotation need to be updated?
"-> ProcessGroupCollection:"
            An optional callable to finish MPU initialization if skip_mpu_initialization
            or lazy_mpu_init is True, otherwise None.
        pg_collection: The process group collection initialized for this run.
        grid_dict: A dictionary mapping module names to HyperCommGrids.
Explain in the docstring when grid_dict and pg_collection_dict are not None.
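A possible wording, based on the gating described in this PR (the exact condition is an assumption):

    Returns:
        pg_collection: The process group collection initialized for this run.
        grid_dict: Mapping from module name ("vision", "language") to its
            HyperCommGrid. None unless model_config.use_dist_train is True.
        pg_collection_dict: Mapping from module name to its
            ProcessGroupCollection. None unless model_config.use_dist_train
            is True.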
        dp = torch.distributed.get_world_size() // (tp * pp * cp)
        print(f"> initialized HyperCommGrid with tp={tp}, pp={pp}, cp={cp}, dp={dp}")
        return pg_collection
    if hasattr(model_config, "use_dist_train") and model_config.use_dist_train:
Overall, I feel use_dist_train is a train config, i.e. cfg.train.use_dist_train. Check where we set FSDP; model config seems weird, but it might be okay.
This section is a bit too long just to create the process groups for multimodal. Offload it to an internal util function, maybe "_create_dist_train_pgs".
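A sketch of that helper; the factory and accessor names are hypothetical, and only the "vision"/"language" module split comes from this PR:

import torch


def _create_dist_train_pgs(model_config) -> tuple[dict, dict]:
    """Build per-module HyperCommGrids and ProcessGroupCollections."""
    total = torch.distributed.get_world_size()
    grid_dict, pg_collection_dict = {}, {}
    for module, world_size in (
        ("vision", model_config.vision_world_size),
        ("language", total - model_config.vision_world_size),
    ):
        grid = build_hyper_comm_grid(module, world_size)  # hypothetical factory
        grid_dict[module] = grid
        pg_collection_dict[module] = grid.to_pg_collection()  # hypothetical accessor
    return grid_dict, pg_collection_dict

initialize_distributed would then shrink to a single call into this helper when use_dist_train is set.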
    vision_model_type: Optional[str] = None

    # parameters for DistTrain
    use_dist_train: bool = False

    # parameters for DistTrain
    use_dist_train: bool = False
    dist_train_vision_chunk_size: Optional[int] = 1
    vision_world_size: Optional[int] = None
Maybe group these into a dist_train dataclass.
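A minimal sketch of that grouping, reusing the field names quoted above (the dataclass name and where it nests are hypothetical):

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class DistTrainConfig:
    """Groups the DistTrain knobs instead of spreading flat fields on the provider."""

    use_dist_train: bool = False
    dist_train_vision_chunk_size: Optional[int] = 1
    vision_world_size: Optional[int] = None


# On the provider, one nested field would replace the three flat ones:
# dist_train: DistTrainConfig = field(default_factory=DistTrainConfig)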
    num_floating_point_operations_model = flop_utils.num_floating_point_operations(config, batch_size=1)
    p2p_communicator = P2PCommunicator(pp_group=pg_collection.pp, config=model_config)
    if config.model.use_dist_train:
        p2p_communicator = config.model.p2p_communicator
If it's not a real field in the config, it won't be saved during checkpointing; use a _ prefix and set it as a private attribute:
"config.model._p2p_communicator"
What does this PR do?
This PR adds DistTrain support for training Qwen3 VL.
Changelog
GitHub Actions CI
See the CI section in the Contributing doc for how to trigger the CI. An NVIDIA developer will need to approve and trigger the CI for external contributors.
Before your PR is "Ready for review"
Pre checks:
If you haven't finished some of the above items you can still open a "Draft" PR.
Additional Information
Summary by CodeRabbit
Release Notes
New Features
Updates
Bug Fixes