Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def __init__(
vision_patch_merger_spec,
pre_process=True,
post_process=True,
pg_collection=pg_collection,
)

self.language_model = Qwen3VLGPTModel(
Expand Down Expand Up @@ -180,6 +181,16 @@ def __init__(
"Qwen3-VL model only supports context parallelism with calculate_per_token_loss enabled"
)

# Expose position_embedding_type, rotary_pos_emb, and decoder for CUDA graph helper compatibility
# The CUDA graph helper expects model.position_embedding_type, model.rotary_pos_emb, and model.decoder,
# but in Qwen3VL these are nested under language_model. This provides direct access.
# Expose these attributes for CUDA graph helper compatibility only when CUDA graph is enabled
cuda_graph_enabled = getattr(self.language_model.config, "cuda_graph_impl", "none") != "none"
if cuda_graph_enabled:
self.position_embedding_type = self.language_model.position_embedding_type
self.rotary_pos_emb = self.language_model.rotary_pos_emb
self.decoder = self.language_model.decoder

def shared_embedding_or_output_weight(self):
"""This is a convenience method to surface the language model's word embeddings, which is
necessary for `finalize_model_grads._allreduce_word_embedding_grads`."""
Expand Down
7 changes: 5 additions & 2 deletions src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(
rotary_interleaved: bool = False,
seq_len_interpolation_factor: Optional[float] = None,
rotary_base: int = 10000,
cp_group: torch.distributed.ProcessGroup = None,
) -> None:
super().__init__()

Expand All @@ -69,6 +70,8 @@ def __init__(

        # Default mrope section is [24, 20, 20]; if none is provided, fall back to this default.
self.mrope_section = [24, 20, 20]
assert cp_group is not None, "cp_group is required"
self.cp_group = cp_group

def apply_interleaved_mrope(self, freqs, mrope_section):
"""Apply interleaved MRoPE to 3D rotary embeddings.
Expand Down Expand Up @@ -127,10 +130,10 @@ def forward(

# shape (seq_length, bs, 1, 2 * dim)
emb = emb[..., None, :].transpose(0, 1).contiguous()
if parallel_state.get_context_parallel_world_size() > 1 and not self.is_thd_format:
if self.cp_group.size() > 1 and not self.is_thd_format:
            # slice rotary_pos_emb along sequence dimension and select the partition of the current
# CP rank
emb = get_pos_emb_on_this_cp_rank(emb, 0, parallel_state.get_context_parallel_group())
emb = get_pos_emb_on_this_cp_rank(emb, 0, self.cp_group)
return emb


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def __init__(
rotary_interleaved=self.config.rotary_interleaved,
seq_len_interpolation_factor=seq_len_interpolation_factor,
rotary_base=rotary_base,
cp_group=self.pg_collection.cp,
)
self.mrope_section = self.config.mrope_section
assert self.mrope_section is not None, (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from typing import Optional, Union

import torch
from megatron.core import parallel_state, tensor_parallel
from megatron.core import tensor_parallel
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding
from megatron.core.enums import Fp8Recipe
Expand Down Expand Up @@ -50,6 +50,7 @@

from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.transformer_config import Qwen3VLTransformerConfig
from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.utils import Qwen3VLVisionPatchMerger
from megatron.core.process_groups_config import ProcessGroupCollection


class Qwen3VLVisionTransformerBlock(TransformerBlock):
Expand All @@ -67,6 +68,7 @@ def __init__(
# model_comm_pgs: ModelCommProcessGroups = None,
vp_stage: Optional[int] = None,
patch_merger_spec: ModuleSpec = None,
pg_collection: Optional[torch.distributed.ProcessGroup] = None,
):
assert post_process and pre_process, "not support pp for deepstack_merger_list"
super().__init__(
Expand All @@ -77,7 +79,13 @@ def __init__(
post_process=post_process,
# model_comm_pgs=model_comm_pgs,
vp_stage=vp_stage,
pg_collection=pg_collection,
)
self.pg_collection = pg_collection
self.cp_group = pg_collection.cp
self.tp_group = pg_collection.tp
self.pp_group = pg_collection.pp
Comment on lines +56 to +87
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, verify the file exists and read the relevant section
cat -n src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/transformer_block.py | head -100

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 4545


🏁 Script executed:

# Search for usages of Qwen3VLVisionTransformerBlock to see how it's instantiated
rg "Qwen3VLVisionTransformerBlock" --type py -B 2 -A 5

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 2769


🏁 Script executed:

# Check if TransformerBlock has any validation or pg_collection handling
rg "class TransformerBlock" --type py -A 20

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 53


🏁 Script executed:

# Search for how pg_collection is used in similar patterns
rg "pg_collection\s*=\s*None" --type py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 53


🏁 Script executed:

# Get the full instantiation context
rg "self.decoder = Qwen3VLVisionTransformerBlock" --type py -A 10

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1228


🏁 Script executed:

# Check TransformerBlock parent class from megatron.core
fd "transformer_block.py" --type f | grep -E "megatron.*core"

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 53


🏁 Script executed:

# Search for any validation of pg_collection in the codebase
rg "pg_collection.*is.*None|if pg_collection" --type py -B 2 -A 2

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 4394


🏁 Script executed:

# Find the Qwen3VLVisionModel class and see how pg_collection is set
rg "class Qwen3VLVisionModel" --type py -A 50 | head -100

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 5839


🏁 Script executed:

# Check if there's parent class initialization that sets pg_collection
rg "self.pg_collection\s*=" --type py -B 3 -A 1 src/megatron/bridge/models/qwen_vl/

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 2367


🏁 Script executed:

# Look at the model_provider.py pattern to understand how pg_collection should be handled
cat -n src/megatron/bridge/models/model_provider.py | grep -A 10 "pg_collection is None"

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 698


Add validation for pg_collection parameter before dereferencing.

The pg_collection parameter defaults to None (line 67) but is dereferenced immediately at lines 81-83 without null checking (pg_collection.cp, .tp, .pp). This will raise AttributeError at runtime if the parameter is not provided. Either make pg_collection required (remove the default value), initialize it when None (e.g., via ProcessGroupCollection.use_mpu_process_groups()), or add explicit validation with a meaningful error message.

Note: The parent class Qwen3VLVisionModel has the same issue and should be fixed as well.

🤖 Prompt for AI Agents
In `@src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/transformer_block.py`
around lines 56 - 83, The constructor Qwen3VLVisionTransformerBlock.__init__
dereferences pg_collection into cp_group/tp_group/pp_group without null
checking; update the init to validate and handle a missing pg_collection by
either requiring it (remove the default None) or initializing it when None
(e.g., call ProcessGroupCollection.use_mpu_process_groups() or another factory)
and then assign self.cp_group/self.tp_group/self.pp_group, or raise a clear
ValueError if pg_collection is None; apply the same null-check/fix pattern to
Qwen3VLVisionModel.__init__ to prevent AttributeError.


self.deepstack_visual_indexes = config.deepstack_visual_indexes
self.deepstack_merger_list = nn.ModuleList(
[
Expand Down Expand Up @@ -141,7 +149,7 @@ def checkpoint_handler(forward_func):
forward_func,
self.config.distribute_saved_activations,
tensor_parallel.random.get_cuda_rng_tracker,
parallel_state.get_tensor_model_parallel_group(),
self.tp_group,
hidden_states,
attention_mask,
context,
Expand Down Expand Up @@ -503,7 +511,7 @@ def checkpoint_handler(forward_func):
forward_func,
self.config.distribute_saved_activations,
tensor_parallel.random.get_cuda_rng_tracker,
parallel_state.get_tensor_model_parallel_group(),
self.tp_group,
hidden_states,
attention_mask,
context,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ def __init__(
self.spatial_merge_size = transformer_config.spatial_merge_size
self.patch_size = transformer_config.patch_size
self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
self.pg_collection = pg_collection
self.tp_group = self.pg_collection.tp

assert transformer_config.context_parallel_size == 1, (
f"context_parallel_size should be 1 in vision model but got {transformer_config.context_parallel_size}"
Expand All @@ -79,6 +81,7 @@ def __init__(
post_process=self.post_process,
post_layer_norm=False,
patch_merger_spec=patch_merger_spec,
pg_collection=self.pg_collection,
)

self.merger = None
Expand Down
1 change: 1 addition & 0 deletions src/megatron/bridge/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ def train(
seq_length=config.model.seq_length,
micro_batch_size=config.train.micro_batch_size,
optimizers=[optimizer],
pg_collection=pg_collection,
)

# Track train step elapsed time for throughput logging
Expand Down