diff --git a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py
index 64bc5afa21..3b8d9c5fc4 100644
--- a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py
+++ b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py
@@ -146,6 +146,7 @@ def __init__(
             vision_patch_merger_spec,
             pre_process=True,
             post_process=True,
+            pg_collection=pg_collection,
         )
 
         self.language_model = Qwen3VLGPTModel(
@@ -180,6 +181,16 @@ def __init__(
                 "Qwen3-VL model only supports context parallelism with calculate_per_token_loss enabled"
             )
 
+        # Expose position_embedding_type, rotary_pos_emb, and decoder for CUDA graph helper compatibility.
+        # The CUDA graph helper expects model.position_embedding_type, model.rotary_pos_emb, and model.decoder,
+        # but in Qwen3VL these are nested under language_model, so provide direct access to them here.
+        # The attributes are exposed only when CUDA graphs are enabled.
+        cuda_graph_enabled = getattr(self.language_model.config, "cuda_graph_impl", "none") != "none"
+        if cuda_graph_enabled:
+            self.position_embedding_type = self.language_model.position_embedding_type
+            self.rotary_pos_emb = self.language_model.rotary_pos_emb
+            self.decoder = self.language_model.decoder
+
     def shared_embedding_or_output_weight(self):
         """This is a convenience method to surface the language model's word embeddings, which is
         necessary for `finalize_model_grads._allreduce_word_embedding_grads`."""
diff --git a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/rope.py b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/rope.py
index fe1cf23888..f60e81e5f6 100644
--- a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/rope.py
+++ b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/rope.py
@@ -52,6 +52,7 @@ def __init__(
         rotary_interleaved: bool = False,
         seq_len_interpolation_factor: Optional[float] = None,
         rotary_base: int = 10000,
+        cp_group: Optional[torch.distributed.ProcessGroup] = None,
     ) -> None:
         super().__init__()
 
@@ -69,6 +70,8 @@ def __init__(
             # default mrope section is [24, 20, 20], if no mrope section is provided, use default mrope section
             self.mrope_section = [24, 20, 20]
+        assert cp_group is not None, "cp_group is required"
+        self.cp_group = cp_group
 
     def apply_interleaved_mrope(self, freqs, mrope_section):
         """Apply interleaved MRoPE to 3D rotary embeddings.
@@ -127,10 +130,10 @@ def forward(
         # shape (seq_length, bs, 1, 2 * dim)
         emb = emb[..., None, :].transpose(0, 1).contiguous()
 
-        if parallel_state.get_context_parallel_world_size() > 1 and not self.is_thd_format:
+        if self.cp_group.size() > 1 and not self.is_thd_format:
             # slice rotary_pos_emb along sequence dimension and select the parition of the current
             # CP rank
-            emb = get_pos_emb_on_this_cp_rank(emb, 0, parallel_state.get_context_parallel_group())
+            emb = get_pos_emb_on_this_cp_rank(emb, 0, self.cp_group)
 
         return emb
diff --git a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/text_model.py b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/text_model.py
index 9b4ec3962c..9914fb3cc0 100644
--- a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/text_model.py
+++ b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/text_model.py
@@ -88,6 +88,7 @@ def __init__(
             rotary_interleaved=self.config.rotary_interleaved,
             seq_len_interpolation_factor=seq_len_interpolation_factor,
             rotary_base=rotary_base,
+            cp_group=self.pg_collection.cp,
         )
         self.mrope_section = self.config.mrope_section
         assert self.mrope_section is not None, (
diff --git a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/transformer_block.py b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/transformer_block.py
index aef8988d15..ba4e17e217 100644
--- a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/transformer_block.py
+++ b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/transformer_block.py
@@ -22,7 +22,7 @@
 from typing import Optional, Union
 
 import torch
-from megatron.core import parallel_state, tensor_parallel
+from megatron.core import tensor_parallel
 from megatron.core.dist_checkpointing.mapping import ShardedStateDict
 from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding
 from megatron.core.enums import Fp8Recipe
@@ -50,6 +50,7 @@
 from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.transformer_config import Qwen3VLTransformerConfig
 from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.utils import Qwen3VLVisionPatchMerger
+from megatron.core.process_groups_config import ProcessGroupCollection
 
 
 class Qwen3VLVisionTransformerBlock(TransformerBlock):
@@ -67,6 +68,7 @@ def __init__(
         # model_comm_pgs: ModelCommProcessGroups = None,
         vp_stage: Optional[int] = None,
         patch_merger_spec: ModuleSpec = None,
+        pg_collection: Optional[ProcessGroupCollection] = None,
     ):
         assert post_process and pre_process, "not support pp for deepstack_merger_list"
         super().__init__(
@@ -77,7 +79,13 @@ def __init__(
             post_process=post_process,
             # model_comm_pgs=model_comm_pgs,
             vp_stage=vp_stage,
+            pg_collection=pg_collection,
         )
+        self.pg_collection = pg_collection
+        self.cp_group = pg_collection.cp
+        self.tp_group = pg_collection.tp
+        self.pp_group = pg_collection.pp
+
         self.deepstack_visual_indexes = config.deepstack_visual_indexes
         self.deepstack_merger_list = nn.ModuleList(
             [
@@ -141,7 +149,7 @@ def checkpoint_handler(forward_func):
                     forward_func,
                     self.config.distribute_saved_activations,
                     tensor_parallel.random.get_cuda_rng_tracker,
-                    parallel_state.get_tensor_model_parallel_group(),
+                    self.tp_group,
                     hidden_states,
                     attention_mask,
                     context,
@@ -503,7 +511,7 @@ def checkpoint_handler(forward_func):
                     forward_func,
                     self.config.distribute_saved_activations,
                     tensor_parallel.random.get_cuda_rng_tracker,
-                    parallel_state.get_tensor_model_parallel_group(),
+                    self.tp_group,
                     hidden_states,
                     attention_mask,
                     context,
diff --git a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/vision_model.py b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/vision_model.py
index 5501d6b7e0..d6bca19665 100644
--- a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/vision_model.py
+++ b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/vision_model.py
@@ -55,6 +55,8 @@ def __init__(
         self.spatial_merge_size = transformer_config.spatial_merge_size
         self.patch_size = transformer_config.patch_size
         self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
+        self.pg_collection = pg_collection
+        self.tp_group = self.pg_collection.tp
 
         assert transformer_config.context_parallel_size == 1, (
             f"context_parallel_size should be 1 in vision model but got {transformer_config.context_parallel_size}"
@@ -79,6 +81,7 @@ def __init__(
             post_process=self.post_process,
             post_layer_norm=False,
             patch_merger_spec=patch_merger_spec,
+            pg_collection=self.pg_collection,
         )
 
         self.merger = None
diff --git a/src/megatron/bridge/training/train.py b/src/megatron/bridge/training/train.py
index 514533245e..ca676376c8 100644
--- a/src/megatron/bridge/training/train.py
+++ b/src/megatron/bridge/training/train.py
@@ -248,6 +248,7 @@ def train(
            seq_length=config.model.seq_length,
            micro_batch_size=config.train.micro_batch_size,
            optimizers=[optimizer],
+           pg_collection=pg_collection,
        )
 
     # Track train step elapsed time for throughput logging
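
The common thread of the patch is replacing reads of the global megatron.core parallel_state with process groups injected explicitly through a pg_collection, so each module uses the CP/TP groups it was constructed with. The sketch below illustrates that injection pattern only; SimplePGCollection and ToyRotaryEmbedding are hypothetical stand-ins (not the real megatron.core ProcessGroupCollection or the Qwen3VL rotary embedding), and the naive even chunking is a simplification of get_pos_emb_on_this_cp_rank, which uses Megatron's CP load-balanced layout.

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class SimplePGCollection:
    # Illustrative stand-in for a collection carrying context/tensor/pipeline parallel groups.
    cp: Optional[torch.distributed.ProcessGroup] = None
    tp: Optional[torch.distributed.ProcessGroup] = None
    pp: Optional[torch.distributed.ProcessGroup] = None


class ToyRotaryEmbedding(torch.nn.Module):
    # Receives its context-parallel group at construction time, as in the patched rope.py,
    # instead of calling parallel_state.get_context_parallel_group() at forward time.
    def __init__(self, cp_group: Optional[torch.distributed.ProcessGroup] = None):
        super().__init__()
        assert cp_group is not None, "cp_group is required"
        self.cp_group = cp_group

    def forward(self, emb: torch.Tensor) -> torch.Tensor:
        # Slice the sequence dimension across CP ranks only when CP is in use,
        # mirroring `if self.cp_group.size() > 1` in the patched forward().
        if self.cp_group.size() > 1:
            chunks = emb.chunk(self.cp_group.size(), dim=0)
            return chunks[torch.distributed.get_rank(group=self.cp_group)].contiguous()
        return emb


# Wiring, assuming torch.distributed is already initialized and the groups exist:
#   pgs = SimplePGCollection(cp=cp_group, tp=tp_group, pp=pp_group)
#   rope = ToyRotaryEmbedding(cp_group=pgs.cp)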