From a456f677ffaa510a3684c05bba859864d382ba36 Mon Sep 17 00:00:00 2001 From: shifangx Date: Fri, 13 Feb 2026 04:36:29 -0800 Subject: [PATCH 1/7] qwen3-vl m4 leftover --- .../models/qwen_vl/modelling_qwen3_vl/model.py | 1 + .../models/qwen_vl/modelling_qwen3_vl/rope.py | 8 +++++--- .../qwen_vl/modelling_qwen3_vl/text_model.py | 1 + .../modelling_qwen3_vl/transformer_block.py | 15 +++++++++++---- .../models/qwen_vl/modelling_qwen3_vl/utils.py | 2 +- .../qwen_vl/modelling_qwen3_vl/vision_model.py | 3 +++ 6 files changed, 22 insertions(+), 8 deletions(-) diff --git a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py index 64bc5afa21..505be5989b 100644 --- a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py +++ b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py @@ -146,6 +146,7 @@ def __init__( vision_patch_merger_spec, pre_process=True, post_process=True, + pg_collection=pg_collection, ) self.language_model = Qwen3VLGPTModel( diff --git a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/rope.py b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/rope.py index fe1cf23888..bb7c437463 100644 --- a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/rope.py +++ b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/rope.py @@ -17,7 +17,6 @@ import torch import torch.nn as nn -from megatron.core import parallel_state from megatron.core.models.common.embeddings.rope_utils import ( _apply_rotary_pos_emb_bshd, get_pos_emb_on_this_cp_rank, @@ -52,6 +51,7 @@ def __init__( rotary_interleaved: bool = False, seq_len_interpolation_factor: Optional[float] = None, rotary_base: int = 10000, + cp_group: torch.distributed.ProcessGroup = None, ) -> None: super().__init__() @@ -69,6 +69,8 @@ def __init__( # default mrope section is [24, 20, 20], if no mrope section is provided, use default mrope section self.mrope_section = [24, 20, 20] + assert cp_group is not None, 
"cp_group is required" + self.cp_group = cp_group def apply_interleaved_mrope(self, freqs, mrope_section): """Apply interleaved MRoPE to 3D rotary embeddings. @@ -127,10 +129,10 @@ def forward( # shape (seq_length, bs, 1, 2 * dim) emb = emb[..., None, :].transpose(0, 1).contiguous() - if parallel_state.get_context_parallel_world_size() > 1 and not self.is_thd_format: + if self.cp_group.size() > 1 and not self.is_thd_format: # slice rotary_pos_emb along sequence dimension and select the parition of the current # CP rank - emb = get_pos_emb_on_this_cp_rank(emb, 0, parallel_state.get_context_parallel_group()) + emb = get_pos_emb_on_this_cp_rank(emb, 0, self.cp_group) return emb diff --git a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/text_model.py b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/text_model.py index 9b4ec3962c..9914fb3cc0 100644 --- a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/text_model.py +++ b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/text_model.py @@ -88,6 +88,7 @@ def __init__( rotary_interleaved=self.config.rotary_interleaved, seq_len_interpolation_factor=seq_len_interpolation_factor, rotary_base=rotary_base, + cp_group=self.pg_collection.cp, ) self.mrope_section = self.config.mrope_section assert self.mrope_section is not None, ( diff --git a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/transformer_block.py b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/transformer_block.py index aef8988d15..6d71d07114 100644 --- a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/transformer_block.py +++ b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/transformer_block.py @@ -22,7 +22,7 @@ from typing import Optional, Union import torch -from megatron.core import parallel_state, tensor_parallel +from megatron.core import tensor_parallel from megatron.core.dist_checkpointing.mapping import ShardedStateDict from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding from 
megatron.core.enums import Fp8Recipe @@ -67,6 +67,7 @@ def __init__( # model_comm_pgs: ModelCommProcessGroups = None, vp_stage: Optional[int] = None, patch_merger_spec: ModuleSpec = None, + pg_collection: Optional[torch.distributed.ProcessGroup] = None, ): assert post_process and pre_process, "not support pp for deepstack_merger_list" super().__init__( @@ -77,7 +78,13 @@ def __init__( post_process=post_process, # model_comm_pgs=model_comm_pgs, vp_stage=vp_stage, + pg_collection=pg_collection, ) + self.pg_collection = pg_collection + self.cp_group = pg_collection.cp + self.tp_group = pg_collection.tp + self.pp_group = pg_collection.pp + self.deepstack_visual_indexes = config.deepstack_visual_indexes self.deepstack_merger_list = nn.ModuleList( [ @@ -141,7 +148,7 @@ def checkpoint_handler(forward_func): forward_func, self.config.distribute_saved_activations, tensor_parallel.random.get_cuda_rng_tracker, - parallel_state.get_tensor_model_parallel_group(), + self.tp_group, hidden_states, attention_mask, context, @@ -391,7 +398,7 @@ def sharded_state_dict( layer_prefix = f"{prefix}layers." num_layers = self.config.num_layers for layer in self.layers: - offset = get_transformer_layer_offset(self.config, self.vp_stage) + offset = get_transformer_layer_offset(self.config, self.vp_stage, pp_rank=self.pp_group.rank()) global_layer_offset = layer.layer_number - 1 # self.layer_number starts at 1 state_dict_prefix = f"{layer_prefix}{global_layer_offset - offset}." 
# module list index in TransformerBlock # pylint: disable=line-too-long @@ -503,7 +510,7 @@ def checkpoint_handler(forward_func): forward_func, self.config.distribute_saved_activations, tensor_parallel.random.get_cuda_rng_tracker, - parallel_state.get_tensor_model_parallel_group(), + self.tp_group, hidden_states, attention_mask, context, diff --git a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/utils.py b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/utils.py index 9af8384819..b714d0f74b 100644 --- a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/utils.py +++ b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/utils.py @@ -128,7 +128,7 @@ def __init__( self.input_size = config.hidden_size if self.use_postshuffle_norm: self.input_size = self.hidden_size - tp_group = get_tensor_model_parallel_group_if_none(tp_group, is_expert=False) + self.tp_group = get_tensor_model_parallel_group_if_none(tp_group, is_expert=False) self.patch_norm = build_module( submodules.patch_norm, diff --git a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/vision_model.py b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/vision_model.py index 5501d6b7e0..d6bca19665 100644 --- a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/vision_model.py +++ b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/vision_model.py @@ -55,6 +55,8 @@ def __init__( self.spatial_merge_size = transformer_config.spatial_merge_size self.patch_size = transformer_config.patch_size self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size + self.pg_collection = pg_collection + self.tp_group = self.pg_collection.tp assert transformer_config.context_parallel_size == 1, ( f"context_parallel_size should be 1 in vision model but got {transformer_config.context_parallel_size}" @@ -79,6 +81,7 @@ def __init__( post_process=self.post_process, post_layer_norm=False, patch_merger_spec=patch_merger_spec, + pg_collection=self.pg_collection, ) self.merger = None From 
9df8c964f0956d6e0e52ea029231e628cded9361 Mon Sep 17 00:00:00 2001 From: Shifang Xu Date: Sat, 14 Feb 2026 06:58:07 -0800 Subject: [PATCH 2/7] create pg groups and pass to Qwen3VLModel --- src/megatron/bridge/models/gpt_provider.py | 13 ++ .../qwen_vl/modelling_qwen3_vl/model.py | 2 +- .../models/qwen_vl/qwen3_vl_provider.py | 4 + src/megatron/bridge/training/config.py | 15 ++- src/megatron/bridge/training/initialize.py | 127 +++++++++++++----- src/megatron/bridge/training/setup.py | 32 ++++- .../bridge/training/utils/train_utils.py | 12 ++ 7 files changed, 166 insertions(+), 39 deletions(-) diff --git a/src/megatron/bridge/models/gpt_provider.py b/src/megatron/bridge/models/gpt_provider.py index 085daac409..e14ffd1c45 100644 --- a/src/megatron/bridge/models/gpt_provider.py +++ b/src/megatron/bridge/models/gpt_provider.py @@ -198,6 +198,19 @@ class GPTModelProvider(TransformerConfig, ModelProviderMixin[MCoreGPTModel]): _pg_collection: Optional[ProcessGroupCollection] = None + # parameters for DistTrain + use_dist_train: bool = False + dist_train_vision_chunk_size: Optional[int] = 1 + vision_world_size: Optional[int] = None + language_world_size: Optional[int] = None + vision_tensor_model_parallel_size: Optional[int] = None + vision_pipeline_model_parallel_size: Optional[int] = None + vision_context_parallel_size: Optional[int] = None + vision_expert_tensor_parallel_size: Optional[int] = None + vision_expert_model_parallel_size: Optional[int] = None + add_vision_module: bool = True + add_language_module: bool = True + def provide(self, pre_process=None, post_process=None, vp_stage=None) -> MCoreGPTModel: """Configure and instantiate a Megatron Core GPT model based on this configuration. 
diff --git a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py index 566b0526d9..f147c3d447 100644 --- a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py +++ b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py @@ -90,7 +90,7 @@ def __init__( self.post_process = post_process self.add_encoder = add_encoder self.add_decoder = add_decoder - + print(f"for debug, rank {torch.distributed.get_rank()} in Qwen3VLModel.__init__(), self.config.use_dist_train: {self.config.use_dist_train}, self.add_encoder: {self.add_encoder}, self.add_decoder: {self.add_decoder}") self.encoder_hidden_state = None self.vision_model = None self.language_model = None diff --git a/src/megatron/bridge/models/qwen_vl/qwen3_vl_provider.py b/src/megatron/bridge/models/qwen_vl/qwen3_vl_provider.py index c7371f8670..13cbba372b 100644 --- a/src/megatron/bridge/models/qwen_vl/qwen3_vl_provider.py +++ b/src/megatron/bridge/models/qwen_vl/qwen3_vl_provider.py @@ -125,6 +125,8 @@ def provide(self, pre_process=None, post_process=None, vp_stage=None) -> Qwen3VL pre_process=pre_process, post_process=post_process, pg_collection=self._pg_collection, + add_encoder=self.add_encoder, + add_decoder=self.add_decoder, ) # Apply freeze options if any are enabled for fine-tuning @@ -276,6 +278,8 @@ def provide(self, pre_process=None, post_process=None, vp_stage=None) -> Qwen3VL pre_process=pre_process, post_process=post_process, pg_collection=self._pg_collection, + add_encoder=self.add_encoder, + add_decoder=self.add_decoder, ) # Apply freeze options if any are enabled for fine-tuning diff --git a/src/megatron/bridge/training/config.py b/src/megatron/bridge/training/config.py index 9d1ae10dc5..406ffff013 100644 --- a/src/megatron/bridge/training/config.py +++ b/src/megatron/bridge/training/config.py @@ -1391,8 +1391,13 @@ def set_data_parallel_size(self) -> None: This method calculates the data parallel size needed by 
setup methods, without triggering full validation or finalization of Megatron Core configs. """ + if hasattr(self.model, "use_dist_train") and self.model.use_dist_train: + # use language world size to calculate data parallel size + world_size = self.model.language_world_size + else: + # use total world size to calculate data parallel size + world_size = get_world_size_safe() # Calculate data parallel size (needed for comm overlap setup) - world_size = get_world_size_safe() self.data_parallel_size = self.get_data_parallel_size(world_size) # Set data_parallel_size on comm_overlap config if present @@ -1464,7 +1469,13 @@ def validate(self) -> None: # Distributed - ensure data_parallel_size is calculated (might already be set by set_data_parallel_size) if not hasattr(self, "data_parallel_size") or self.data_parallel_size is None: - world_size = get_world_size_safe() + if hasattr(self.model, "use_dist_train") and self.model.use_dist_train: + # use language world size to calculate data parallel size + world_size = self.model.language_world_size + else: + # use total world size to calculate data parallel size + world_size = get_world_size_safe() + # Calculate data parallel size (needed for comm overlap setup) self.data_parallel_size = self.get_data_parallel_size(world_size) # Set data_parallel_size on comm_overlap config if present if self.comm_overlap is not None: diff --git a/src/megatron/bridge/training/initialize.py b/src/megatron/bridge/training/initialize.py index 5229906b0e..b2942d938a 100644 --- a/src/megatron/bridge/training/initialize.py +++ b/src/megatron/bridge/training/initialize.py @@ -16,7 +16,8 @@ import os import time import warnings -from typing import Callable, Optional +from copy import copy +from typing import Callable, Optional, Tuple import torch import torch.distributed @@ -43,6 +44,7 @@ from megatron.bridge.models import GPTModelProvider, T5ModelProvider from megatron.bridge.training.config import ConfigContainer, DistributedInitConfig, 
RerunStateMachineConfig, RNGConfig +from megatron.bridge.training.utils.train_utils import is_rank_in_pg from megatron.bridge.utils.common_utils import ( get_local_rank_preinit, get_master_addr_safe, @@ -51,7 +53,6 @@ get_world_size_safe, ) - def initialize_megatron( cfg: ConfigContainer, allow_no_cuda: bool = False, @@ -159,7 +160,7 @@ def torch_dist_init( skip_mpu_initialization: bool, restart_store: Optional[torch.distributed.Store] = None, use_inprocess_restart: bool = False, -) -> Callable[[], None] | ProcessGroupCollection | None: +) -> Tuple[ProcessGroupCollection, Optional[dict[str, HyperCommGrid]], Optional[dict[str, ProcessGroupCollection]]]: """Initialize torch.distributed and dependent components. Handles the core distributed setup, including process group initialization, @@ -177,13 +178,14 @@ def torch_dist_init( skip_mpu_initialization: If True, returns a function to finish MPU setup later. Returns: - An optional callable to finish MPU initialization if skip_mpu_initialization - or lazy_mpu_init is True, otherwise None. + pg_collection: The process group collection initialized for this run. + grid_dict: A dictionary mapping module names to HyperCommGrids. + pg_collection_dict: A dictionary mapping module names to ProcessGroupCollections. """ def finish_mpu_init() -> ProcessGroupCollection: # Pytorch distributed. 
- pg_collection = _initialize_distributed( + pg_collection, grid_dict, pg_collection_dict = _initialize_distributed( model_config=model_config, dist_config=dist_config, num_distributed_optimizer_instances=num_distributed_optimizer_instances, @@ -207,10 +209,10 @@ def finish_mpu_init() -> ProcessGroupCollection: if model_config.num_moe_experts is not None: MoEAuxLossAutoScaler.set_loss_scale(torch.ones(1, device=torch.cuda.current_device())) - return pg_collection + return pg_collection, grid_dict, pg_collection_dict if skip_mpu_initialization: - return None + return None, None, None if dist_config.lazy_init: # delayed initialization of DDP-related stuff @@ -219,14 +221,14 @@ def finish_mpu_init() -> ProcessGroupCollection: # and return function for external DDP manager # to call when it has DDP initialized parallel_state.set_tensor_model_parallel_rank(get_rank_safe()) - return finish_mpu_init + return finish_mpu_init, None, None # Megatron's MPU is the master. Complete initialization right away. 
- pg_collection = finish_mpu_init() + pg_collection, grid_dict, pg_collection_dict = finish_mpu_init() if model_config.tp_comm_overlap: _initialize_tp_communicators(model_config, micro_batch_size) - return pg_collection + return pg_collection, grid_dict, pg_collection_dict def init_rerun_state(rerun_state_machine_config: RerunStateMachineConfig) -> None: @@ -386,9 +388,14 @@ def _create_pg_collection( num_distributed_optimizer_instances: int, get_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]] = None, get_position_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]] = None, -) -> ProcessGroupCollection: - """Create all process groups via HyperCommGrid and return a ProcessGroupCollection.""" - world_size = torch.distributed.get_world_size() + world_size: int = None, + rank_offset: int = None, +) -> Tuple[ProcessGroupCollection, HyperCommGrid]: + """Create all process groups via HyperCommGrid and return a ProcessGroupCollection and a HyperCommGrid.""" + if world_size is None: + world_size = torch.distributed.get_world_size() + if rank_offset is None: + rank_offset = 0 tp_size = int(model_config.tensor_model_parallel_size) pp_size = int(model_config.pipeline_model_parallel_size) cp_size = int(model_config.context_parallel_size) if getattr(model_config, "context_parallel_size", 1) else 1 @@ -400,7 +407,7 @@ def _create_pg_collection( grid = HyperCommGrid( shape=[tp_size, cp_size, dp_size, pp_size], dim_names=["tp", "cp", "dp", "pp"], - rank_offset=0, + rank_offset=rank_offset, backend="nccl", ) # Core groups @@ -440,7 +447,7 @@ def _create_pg_collection( expert_grid = HyperCommGrid( shape=[expert_tp_size, ep_size, inner_expt_dp_size, num_distributed_optimizer_instances, pp_size], dim_names=["tp", "ep", "inner_dp", "outer_dp", "pp"], - rank_offset=0, + rank_offset=rank_offset, backend="nccl", ) dp_group_dims: list[str] = ["inner_dp", "outer_dp"] @@ -450,7 +457,7 @@ def _create_pg_collection( expert_grid = HyperCommGrid( 
shape=[expert_tp_size, ep_size, expt_dp_size, pp_size], dim_names=["tp", "ep", "dp", "pp"], - rank_offset=0, + rank_offset=rank_offset, backend="nccl", ) dp_group_dims = ["dp"] @@ -522,7 +529,7 @@ def _create_pg_collection( inter_dist_opt=inter_dist_opt_pg, intra_dist_opt=intra_dist_opt_pg, ) - return pg_collection + return pg_collection, grid def _initialize_distributed( @@ -533,8 +540,13 @@ def _initialize_distributed( get_position_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]], restart_store: Optional[torch.distributed.Store] = None, use_inprocess_restart: bool = False, -) -> ProcessGroupCollection: - """Initialize torch.distributed and core model parallel.""" +) -> Tuple[ProcessGroupCollection, Optional[dict[str, HyperCommGrid]], Optional[dict[str, ProcessGroupCollection]]]: + """Initialize torch.distributed and core model parallel. + Returns: + pg_collection: The process group collection initialized for this run. + grid_dict: A dictionary mapping module names to HyperCommGrids. + pg_collection_dict: A dictionary mapping module names to ProcessGroupCollections. + """ device_count = torch.cuda.device_count() if torch.distributed.is_initialized(): @@ -597,19 +609,66 @@ def _initialize_distributed( # Use HyperCommGrid to create local parallel groups passed through functions # instead of relying on mcore's global parallel state (mpu) variables. 
parallel_state._set_global_memory_buffer() - pg_collection = _create_pg_collection( - model_config, - num_distributed_optimizer_instances, - get_embedding_ranks=get_embedding_ranks, - get_position_embedding_ranks=get_position_embedding_ranks, - ) - if get_rank_safe() == 0: - tp = int(model_config.tensor_model_parallel_size) - pp = int(model_config.pipeline_model_parallel_size) - cp = int(model_config.context_parallel_size) if getattr(model_config, "context_parallel_size", 1) else 1 - dp = torch.distributed.get_world_size() // (tp * pp * cp) - print(f"> initialized HyperCommGrid with tp={tp}, pp={pp}, cp={cp}, dp={dp}") - return pg_collection + if hasattr(model_config, "use_dist_train") and model_config.use_dist_train: + vision_model_config = copy(model_config) + vision_model_config.world_size = model_config.vision_world_size + vision_model_config.rank_offset = 0 + vision_model_config.tensor_model_parallel_size = model_config.vision_tensor_model_parallel_size + vision_model_config.pipeline_model_parallel_size = model_config.vision_pipeline_model_parallel_size + vision_model_config.context_parallel_size = model_config.vision_context_parallel_size + vision_model_config.expert_tensor_parallel_size = model_config.vision_expert_tensor_parallel_size + vision_model_config.expert_model_parallel_size = model_config.vision_expert_model_parallel_size + language_model_config = copy(model_config) + language_model_config.world_size = model_config.language_world_size + language_model_config.rank_offset = model_config.vision_world_size + vision_pg_collection, vision_grid = _create_pg_collection( + vision_model_config, + num_distributed_optimizer_instances, + get_embedding_ranks=get_embedding_ranks, + get_position_embedding_ranks=get_position_embedding_ranks, + world_size=model_config.vision_world_size, + rank_offset=0, + ) + language_pg_collection, language_grid = _create_pg_collection( + language_model_config, + num_distributed_optimizer_instances, + 
get_embedding_ranks=get_embedding_ranks, + get_position_embedding_ranks=get_position_embedding_ranks, + world_size=model_config.language_world_size, + rank_offset=model_config.vision_world_size, + ) + grid_dict = {'vision_module': vision_grid, 'language_module': language_grid} + pg_collection_dict = {'vision_module': vision_pg_collection, 'language_module': language_pg_collection} + assert is_rank_in_pg(language_pg_collection) or is_rank_in_pg(vision_pg_collection), f"Rank {get_rank_safe()} should be in either the language or vision process group collection." + assert not (is_rank_in_pg(language_pg_collection) and is_rank_in_pg(vision_pg_collection)), f"Rank {get_rank_safe()} should not be in both the language and vision process group collection." + pg_collection = vision_pg_collection if is_rank_in_pg(vision_pg_collection) else language_pg_collection + if get_rank_safe() == 0: + tp = int(vision_model_config.tensor_model_parallel_size) + pp = int(vision_model_config.pipeline_model_parallel_size) + cp = int(vision_model_config.context_parallel_size) if getattr(vision_model_config, "context_parallel_size", 1) else 1 + dp = vision_model_config.vision_world_size // (tp * pp * cp) + print(f"> initialized HyperCommGrid for vision model with tp={tp}, pp={pp}, cp={cp}, dp={dp}") + tp = int(language_model_config.tensor_model_parallel_size) + pp = int(language_model_config.pipeline_model_parallel_size) + cp = int(language_model_config.context_parallel_size) if getattr(language_model_config, "context_parallel_size", 1) else 1 + dp = language_model_config.language_world_size // (tp * pp * cp) + print(f"> initialized HyperCommGrid for language model with tp={tp}, pp={pp}, cp={cp}, dp={dp}") + else: + pg_collection, grid = _create_pg_collection( + model_config, + num_distributed_optimizer_instances, + get_embedding_ranks=get_embedding_ranks, + get_position_embedding_ranks=get_position_embedding_ranks, + ) + grid_dict = {'language_module': grid} + pg_collection_dict = 
{'language_module': pg_collection} + if get_rank_safe() == 0: + tp = int(model_config.tensor_model_parallel_size) + pp = int(model_config.pipeline_model_parallel_size) + cp = int(model_config.context_parallel_size) if getattr(model_config, "context_parallel_size", 1) else 1 + dp = torch.distributed.get_world_size() // (tp * pp * cp) + print(f"> initialized HyperCommGrid with tp={tp}, pp={pp}, cp={cp}, dp={dp}") + return pg_collection, grid_dict, pg_collection_dict else: # Use the original mcore parallel_state.initialize_model_parallel approach if parallel_state.model_parallel_is_initialized(): @@ -645,7 +704,7 @@ def _initialize_distributed( f"{parallel_state.get_pipeline_model_parallel_world_size()}" ) # Return a ProcessGroupCollection using mpu process groups - return ProcessGroupCollection.use_mpu_process_groups() + return ProcessGroupCollection.use_mpu_process_groups(), None, None def _set_random_seed( diff --git a/src/megatron/bridge/training/setup.py b/src/megatron/bridge/training/setup.py index c2081c29c1..b463c69838 100644 --- a/src/megatron/bridge/training/setup.py +++ b/src/megatron/bridge/training/setup.py @@ -27,6 +27,7 @@ from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler from megatron.core.rerun_state_machine import RerunDataIterator from megatron.core.transformer import MegatronModule +from megatron.core.pipeline_parallel.multimodule_communicator import MultiModulePipelineCommunicator from megatron.core.process_groups_config import ProcessGroupCollection from megatron.bridge.data.loaders import setup_data_iterators @@ -49,7 +50,7 @@ finalize_tensor_inspect_post_model_initialization, initialize_tensor_inspect_pre_model_initialization, ) - +from megatron.bridge.training.utils.train_utils import is_rank_in_pg class SetupOutput(NamedTuple): @@ -133,12 +134,39 @@ def setup( # pg_collection is returned from initialize_megatron: # - When use_decentralized_pg=True: uses HyperCommGrid to create local process groups # - When 
use_decentralized_pg=False: uses mpu's global parallel state - pg_collection = initialize_megatron( + pg_collection, grid_dict, pg_collection_dict = initialize_megatron( cfg=cfg, get_embedding_ranks=get_embedding_ranks, get_position_embedding_ranks=get_position_embedding_ranks, restart_store=restart_store, ) + cfg.model.pg_collection = pg_collection + # cfg.model.grid_dict = grid_dict + # cfg.model.pg_collection_dict = pg_collection_dict + if hasattr(cfg.model, "use_dist_train") and cfg.model.use_dist_train: + assert pg_collection is not None, "pg_collection is required" + assert pg_collection_dict['language_module'] is not None, "pg_collection for language module is required" + assert pg_collection_dict['vision_module'] is not None, "pg_collection for vision module is required" + assert grid_dict is not None, "grid_dict is required" + topology = { + 'vision_module': ['language_module'], # vision_module sends forward results to language_module + 'language_module': [], # language_module is the last stage here + } + # Create multimodule communicator + cfg.model.p2p_communicator = MultiModulePipelineCommunicator( + grid_dict, topology, cfg.model, dim_mapping={'b': 0, 's': 1, 'h': 2} + ) + + if is_rank_in_pg(pg_collection_dict['vision_module']): + cfg.model.add_encoder = True + # TODO(shifang): support different data parallel size for vision and language model in the future. 
+ cfg.dataset.num_images = int(cfg.dataset.num_images/cfg.model.dist_train_vision_chunk_size) + else: + cfg.model.add_encoder = False + if is_rank_in_pg(pg_collection_dict['language_module']): + cfg.model.add_decoder = True + else: + cfg.model.add_decoder = False # Set CPU affinity for optimal host-device transfers when fine-grained activation offloading is enabled if cfg.model.fine_grained_activation_offloading: diff --git a/src/megatron/bridge/training/utils/train_utils.py b/src/megatron/bridge/training/utils/train_utils.py index df297a8a32..e76367d5f0 100644 --- a/src/megatron/bridge/training/utils/train_utils.py +++ b/src/megatron/bridge/training/utils/train_utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import dataclasses import inspect import math import os @@ -24,6 +25,7 @@ import torch import torch.nn as nn from megatron.core.num_microbatches_calculator import get_num_microbatches +from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.moe.moe_utils import track_moe_metrics @@ -1095,3 +1097,13 @@ def maybe_inject_state( return partial(forward_step_func, state) else: return forward_step_func + + +def is_rank_in_pg(pg_collection: ProcessGroupCollection) -> bool: + """Check if the current rank is in the process group collection.""" + current_rank = get_rank_safe() + for field in dataclasses.fields(pg_collection): + pg = getattr(pg_collection, field.name, None) + if pg and current_rank in torch.distributed.get_process_group_ranks(pg): + return True + return False From 54c79bf4ac8266aad7c4aef03ac6d6946a50617b Mon Sep 17 00:00:00 2001 From: Shifang Xu Date: Sat, 14 Feb 2026 18:04:28 -0800 Subject: [PATCH 3/7] model init and forward --- .../qwen_vl/modelling_qwen3_vl/model.py | 125 
++++++++++++------ 1 file changed, 82 insertions(+), 43 deletions(-) diff --git a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py index f147c3d447..a2db10e106 100644 --- a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py +++ b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py @@ -21,6 +21,7 @@ ) from megatron.core.models.vision.vit_layer_specs import get_vit_layer_with_transformer_engine_spec from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.pipeline_parallel.utils import is_pp_last_stage from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer import MegatronModule from megatron.core.transformer.spec_utils import ModuleSpec @@ -45,6 +46,7 @@ split_deepstack_embs, ) from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.vision_model import Qwen3VLVisionModel +from typing import List, Dict class Qwen3VLModel(MegatronModule): @@ -83,13 +85,19 @@ def __init__( pg_collection: ProcessGroupCollection = None, ) -> None: super().__init__(config=language_transformer_config) + self.language_transformer_config = language_transformer_config language_transformer_layer_spec.submodules.self_attention.module = Qwen3VLSelfAttention self.pre_process = pre_process self.post_process = post_process - self.add_encoder = add_encoder - self.add_decoder = add_decoder + if language_transformer_config.use_dist_train: + self.add_encoder = add_encoder + self.add_decoder = add_decoder + assert not (self.add_encoder and self.add_decoder) and (self.add_encoder or self.add_decoder), "add_encoder and add_decoder should not be both True or both False for dist train" + else: + self.add_encoder = self.pre_process + self.add_decoder = True print(f"for debug, rank {torch.distributed.get_rank()} in Qwen3VLModel.__init__(), self.config.use_dist_train: {self.config.use_dist_train}, self.add_encoder: {self.add_encoder}, 
self.add_decoder: {self.add_decoder}") self.encoder_hidden_state = None self.vision_model = None @@ -119,7 +127,7 @@ def __init__( self.vp_stage = None self.vp_size = self.config.virtual_pipeline_model_parallel_size - if self.pre_process: + if self.add_encoder: if language_transformer_config.use_hf_vision_model: raise ValueError("use_hf_vision_model is not supported for Qwen3VLModel for now") vision_transformer_layer_spec = get_vit_layer_with_transformer_engine_spec() @@ -135,6 +143,7 @@ def __init__( ) megatron_vision_transformer_config.pipeline_model_parallel_size = 1 megatron_vision_transformer_config.first_pipeline_num_layers = None + self.vision_transformer_config = megatron_vision_transformer_config self.vision_model = Qwen3VLVisionModel( megatron_vision_transformer_config, @@ -144,33 +153,33 @@ def __init__( post_process=True, pg_collection=pg_collection, ) - - self.language_model = Qwen3VLGPTModel( - config=language_transformer_config, - transformer_layer_spec=language_transformer_layer_spec, - vocab_size=language_transformer_config.vocab_size, - max_sequence_length=language_transformer_config.language_max_sequence_length, - parallel_output=parallel_output, - position_embedding_type="mrope", - rotary_percent=language_transformer_config.rotary_percent, - pre_process=self.pre_process, - post_process=self.post_process, - rotary_base=language_transformer_config.rotary_base, - fp16_lm_cross_entropy=language_transformer_config.fp16_lm_cross_entropy, - share_embeddings_and_output_weights=language_transformer_config.share_embeddings_and_output_weights, - scatter_embedding_sequence_parallel=False, - pg_collection=pg_collection, - ) - if pre_process: - assert len(vision_transformer_config.deepstack_visual_indexes) <= len( - self.language_model.decoder.layers - ), ( - "the deepstack_visual_embeds should on the first pp-stage of language model", - f"got {len(vision_transformer_config.deepstack_visual_indexes)} deepstack_visual_indexes, " - f" 
{len(self.language_model.decoder.layers)} language model layers", + if self.add_decoder: + self.language_model = Qwen3VLGPTModel( + config=language_transformer_config, + transformer_layer_spec=language_transformer_layer_spec, + vocab_size=language_transformer_config.vocab_size, + max_sequence_length=language_transformer_config.language_max_sequence_length, + parallel_output=parallel_output, + position_embedding_type="mrope", + rotary_percent=language_transformer_config.rotary_percent, + pre_process=self.pre_process, + post_process=self.post_process, + rotary_base=language_transformer_config.rotary_base, + fp16_lm_cross_entropy=language_transformer_config.fp16_lm_cross_entropy, + share_embeddings_and_output_weights=language_transformer_config.share_embeddings_and_output_weights, + scatter_embedding_sequence_parallel=False, + pg_collection=pg_collection, ) + if pre_process: + assert len(vision_transformer_config.deepstack_visual_indexes) <= len( + self.language_model.decoder.layers + ), ( + "the deepstack_visual_embeds should on the first pp-stage of language model", + f"got {len(vision_transformer_config.deepstack_visual_indexes)} deepstack_visual_indexes, " + f" {len(self.language_model.decoder.layers)} language model layers", + ) - self.share_embeddings_and_output_weights = self.language_model.share_embeddings_and_output_weights + self.share_embeddings_and_output_weights = self.language_model.share_embeddings_and_output_weights if self.pg_collection.cp.size() > 1: assert self.config.calculate_per_token_loss, ( @@ -184,17 +193,33 @@ def shared_embedding_or_output_weight(self): return self.language_model.shared_embedding_or_output_weight() return None - def set_input_tensor(self, input_tensor) -> None: - # This is usually handled in schedules.py but some inference code still - # gives us non-lists or None - if not isinstance(input_tensor, list): - input_tensor = [input_tensor] - assert len(input_tensor) == 1, "input_tensor should only be length 1 for Qwen3VL" - - if 
self.pre_process: - self.encoder_hidden_state = input_tensor[0] + def set_input_tensor(self, input_tensor: List[Dict[str, torch.Tensor]]): + """Set input tensor for pipeline parallelism. + """ + if self.config.use_dist_train: + print(f"for debug, rank {torch.distributed.get_rank()} in Qwen3VLModel.set_input_tensor()") + assert isinstance(input_tensor, list), "Input tensor must be a list" + assert len(input_tensor) == 1, "Input tensor must be a list of length 1" + assert isinstance(input_tensor[0], dict), "Input tensor[0] must be a dictionary" + input_dict = input_tensor[0] + + if 'vision_module' in input_dict: + vision_module_output_tensor = input_dict['vision_module'] + num_chunks = len(self.vision_transformer_config.deepstack_visual_indexes) + 1 + chunks = torch.chunk(vision_module_output_tensor, chunks=num_chunks, dim=0) + self.vision_embeds = chunks[-1] + self.deepstack_feature_lists = chunks[:-1] + if 'language_module' in input_dict: + self.language_model.set_input_tensor(input_dict['language_module']) else: - self.language_model.set_input_tensor(input_tensor[0]) + if not isinstance(input_tensor, list): + input_tensor = [input_tensor] + assert len(input_tensor) == 1, "input_tensor should only be length 1 for Qwen3VL" + + if self.pre_process: + self.encoder_hidden_state = input_tensor[0] + else: + self.language_model.set_input_tensor(input_tensor[0]) def freeze( self, @@ -304,6 +329,7 @@ def forward( ) vision_embeds = None + vision_module_output = None if vision_grid_thw is not None and vision_grid_thw.shape[0] > 0: if cp_size > 1 and self.config.vision_dp_when_cp: if cp_img_num is None: @@ -324,10 +350,16 @@ def forward( ) vision_grid_thw = collapse_thw(vision_grid_thw) if vision_data.shape[0] > 0: - vision_embeds, deepstack_feature_lists = self.vision_model( - hidden_states=vision_data, - grid_thw=vision_grid_thw, - ) + if self.vision_model is not None: + print(f"for debug, rank {torch.distributed.get_rank()} in Qwen3VLModel.forward(), before 
vision_model.forward()") + vision_embeds, deepstack_feature_lists = self.vision_model( + hidden_states=vision_data, + grid_thw=vision_grid_thw, + ) + print(f"for debug, rank {torch.distributed.get_rank()} in Qwen3VLModel.forward(), after vision_model.forward(), vision_embeds.shape: {vision_embeds.shape}") + else: + vision_embeds = self.vision_embeds + deepstack_feature_lists = self.deepstack_feature_lists else: vision_embeds = torch.zeros( (0, self.language_model.config.hidden_size), @@ -335,7 +367,7 @@ def forward( dtype=torch.bfloat16, ) deepstack_feature_lists = [] - for _ in self.vision_model.config.deepstack_visual_indexes: + for _ in self.vision_transformer_config.deepstack_visual_indexes: deepstack_feature_lists.append( torch.zeros( (0, self.language_model.config.hidden_size), @@ -343,6 +375,10 @@ def forward( dtype=torch.bfloat16, ) ) + vision_module_output = deepstack_feature_lists + vision_module_output.append(vision_embeds) + vision_module_output_tensor = torch.cat(vision_module_output, dim=0) + output_vision_module = {'vision_module': vision_module_output_tensor} if cp_size > 1 and self.config.vision_dp_when_cp: vision_embeds = AllGatherVisionEmbeddings.apply( vision_embeds, @@ -356,6 +392,9 @@ def forward( cp_group=self.pg_collection.cp, ) + if self.language_model is None: + # TODO(shifang): need to handle the case when num_images is 0 for some samples. 
+ return output_vision_module combined_embeddings = self.language_model.embedding( input_ids=input_ids, position_ids=None, # NOTE: disable From 2f55a425c9c1bfcf5c2c7ccedc72eba846e41271 Mon Sep 17 00:00:00 2001 From: shifangx Date: Sun, 15 Feb 2026 04:20:25 -0800 Subject: [PATCH 4/7] change --- .../qwen_vl/modelling_qwen3_vl/model.py | 30 +++++++++++++------ src/megatron/bridge/training/train.py | 22 +++++++++++--- 2 files changed, 39 insertions(+), 13 deletions(-) diff --git a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py index a2db10e106..3fa5abcac9 100644 --- a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py +++ b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py @@ -85,6 +85,8 @@ def __init__( pg_collection: ProcessGroupCollection = None, ) -> None: super().__init__(config=language_transformer_config) + self.vision_embeds = None + self.deepstack_feature_lists = None self.language_transformer_config = language_transformer_config language_transformer_layer_spec.submodules.self_attention.module = Qwen3VLSelfAttention @@ -127,7 +129,7 @@ def __init__( self.vp_stage = None self.vp_size = self.config.virtual_pipeline_model_parallel_size - if self.add_encoder: + if self.pre_process: if language_transformer_config.use_hf_vision_model: raise ValueError("use_hf_vision_model is not supported for Qwen3VLModel for now") vision_transformer_layer_spec = get_vit_layer_with_transformer_engine_spec() @@ -144,7 +146,9 @@ def __init__( megatron_vision_transformer_config.pipeline_model_parallel_size = 1 megatron_vision_transformer_config.first_pipeline_num_layers = None self.vision_transformer_config = megatron_vision_transformer_config + print(f"for debug, rank {torch.distributed.get_rank()} in Qwen3VLModel.__init__(), self.vision_transformer_config.deepstack_visual_indexes: {self.vision_transformer_config.deepstack_visual_indexes}") + if self.add_encoder: self.vision_model
= Qwen3VLVisionModel( megatron_vision_transformer_config, vision_transformer_layer_spec, @@ -196,11 +200,14 @@ def shared_embedding_or_output_weight(self): def set_input_tensor(self, input_tensor: List[Dict[str, torch.Tensor]]): """Set input tensor for pipeline parallelism. """ + if input_tensor is None or len(input_tensor) == 0 or input_tensor[0] is None: + print(f"for debug, rank {torch.distributed.get_rank()} in Qwen3VLModel.set_input_tensor(), input_tensor is None or len(input_tensor) == 0 or input_tensor[0] is None") + return if self.config.use_dist_train: - print(f"for debug, rank {torch.distributed.get_rank()} in Qwen3VLModel.set_input_tensor()") - assert isinstance(input_tensor, list), "Input tensor must be a list" - assert len(input_tensor) == 1, "Input tensor must be a list of length 1" - assert isinstance(input_tensor[0], dict), "Input tensor[0] must be a dictionary" + print(f"for debug, rank {torch.distributed.get_rank()} in Qwen3VLModel.set_input_tensor(), input_tensor: {input_tensor}") + assert isinstance(input_tensor, list), "Input tensor must be a list, but got {type(input_tensor)}" + assert len(input_tensor) == 1, "Input tensor must be a list of length 1, but got {len(input_tensor)}" + assert isinstance(input_tensor[0], dict), "Input tensor[0] must be a dictionary, but got {type(input_tensor[0])}" input_dict = input_tensor[0] if 'vision_module' in input_dict: @@ -356,10 +363,15 @@ def forward( hidden_states=vision_data, grid_thw=vision_grid_thw, ) + vision_module_output = deepstack_feature_lists + vision_module_output.append(vision_embeds) + vision_module_output_tensor = torch.cat(vision_module_output, dim=0) + output_vision_module = {'vision_module': vision_module_output_tensor} print(f"for debug, rank {torch.distributed.get_rank()} in Qwen3VLModel.forward(), after vision_model.forward(), vision_embeds.shape: {vision_embeds.shape}") else: vision_embeds = self.vision_embeds deepstack_feature_lists = self.deepstack_feature_lists + print(f"for debug, 
rank {torch.distributed.get_rank()} in Qwen3VLModel.forward(), vision_model is None, vision_embeds.shape: {vision_embeds.shape}") else: vision_embeds = torch.zeros( (0, self.language_model.config.hidden_size), @@ -375,10 +387,7 @@ def forward( dtype=torch.bfloat16, ) ) - vision_module_output = deepstack_feature_lists - vision_module_output.append(vision_embeds) - vision_module_output_tensor = torch.cat(vision_module_output, dim=0) - output_vision_module = {'vision_module': vision_module_output_tensor} + if cp_size > 1 and self.config.vision_dp_when_cp: vision_embeds = AllGatherVisionEmbeddings.apply( vision_embeds, @@ -530,5 +539,8 @@ def forward( **kwargs, ) torch.cuda.nvtx.range_pop() + if self.config.use_dist_train: + if not is_pp_last_stage(self.pg_collection.pp): + return {'language_module': output} return output diff --git a/src/megatron/bridge/training/train.py b/src/megatron/bridge/training/train.py index 235d68784b..85a81affde 100644 --- a/src/megatron/bridge/training/train.py +++ b/src/megatron/bridge/training/train.py @@ -36,8 +36,12 @@ from megatron.core.optimizer.qk_clip import clip_qk from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler from megatron.core.parallel_state import update_pg_timeout -from megatron.core.pipeline_parallel.p2p_communication import P2PCommunicator -from megatron.core.pipeline_parallel.schedules import get_forward_backward_func +from megatron.core.pipeline_parallel.p2p_communication import P2PCommunicator +from megatron.core.pipeline_parallel.multimodule_communicator import MultiModulePipelineCommunicator +from megatron.core.pipeline_parallel.schedules import ( + get_forward_backward_func, + forward_backward_pipelining_without_interleaving, +) from megatron.core.pipeline_parallel.utils import ( is_pp_first_stage, is_pp_last_stage, @@ -244,11 +248,14 @@ def train( if config.logger.log_throughput_to_tensorboard: history_wct = deque(maxlen=config.logger.throughput_window_size + 1) - # Wrap 
forward_backward_func for Full iteration CUDA graph forward_backward_func = get_forward_backward_func( pp_size=pg_collection.pp.size(), vp_size=config.model.virtual_pipeline_model_parallel_size, ) + if config.model.use_dist_train: + # currently, dist train only support non-interleaved pipeline parallel schedule. + forward_backward_func = forward_backward_pipelining_without_interleaving + # Wrap forward_backward_func for Full iteration CUDA graph if config.model.cuda_graph_impl == "local" and CudaGraphScope.full_iteration in config.model.cuda_graph_scope: forward_backward_func = FullCudaGraphWrapper( forward_backward_func, cuda_graph_warmup_steps=config.model.cuda_graph_warmup_steps @@ -258,6 +265,9 @@ def train( print_rank_0(f"Starting training loop at iteration {start_iteration}") num_floating_point_operations_model = flop_utils.num_floating_point_operations(config, batch_size=1) p2p_communicator = P2PCommunicator(pp_group=pg_collection.pp, config=model_config) + if config.model.use_dist_train: + p2p_communicator = config.model.p2p_communicator + assert isinstance(p2p_communicator, MultiModulePipelineCommunicator), "p2p_communicator must be a MultiModulePipelineCommunicator while use_dist_train is True" dp_size = pg_collection.dp.size() if should_fire(callback_manager, "on_train_start"): @@ -732,6 +742,10 @@ def train_step( adjust_tensor_shapes_fn = None # Forward pass. 
+ if cfg.model.use_dist_train: + assert isinstance(p2p_communicator, MultiModulePipelineCommunicator), "p2p_communicator must be a MultiModulePipelineCommunicator while use_dist_train is True" + else: + assert isinstance(p2p_communicator, P2PCommunicator), "p2p_communicator must be a P2PCommunicator while use_dist_train is False" losses_reduced = forward_backward_func( forward_step_func=forward_step_func, data_iterator=forward_backward_data_iterator, @@ -790,7 +804,7 @@ def train_step( if train_config.empty_unused_memory_level >= 2: torch.cuda.empty_cache() - if is_pp_last_stage(pg_collection.pp): + if cfg.model.p2p_communicator.is_pp_last_stage: # Average loss across microbatches. loss_reduced = {} From 4a8492fb1e5ea77dacd00a2290b532e6bdc26e88 Mon Sep 17 00:00:00 2001 From: shifangx Date: Sun, 15 Feb 2026 05:00:49 -0800 Subject: [PATCH 5/7] delete debug message --- .../bridge/models/qwen_vl/modelling_qwen3_vl/model.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py index 3fa5abcac9..f25337d92c 100644 --- a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py +++ b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py @@ -100,7 +100,6 @@ def __init__( else: self.add_encoder = self.pre_process self.add_decoder = True - print(f"for debug, rank {torch.distributed.get_rank()} in Qwen3VLModel.__init__(), self.config.use_dist_train: {self.config.use_dist_train}, self.add_encoder: {self.add_encoder}, self.add_decoder: {self.add_decoder}") self.encoder_hidden_state = None self.vision_model = None self.language_model = None @@ -146,7 +145,6 @@ def __init__( megatron_vision_transformer_config.pipeline_model_parallel_size = 1 megatron_vision_transformer_config.first_pipeline_num_layers = None self.vision_transformer_config = megatron_vision_transformer_config - print(f"for debug, rank {torch.distributed.get_rank()} in 
Qwen3VLModel.__init__(), self.vision_transformer_config.deepstack_visual_indexes: {self.vision_transformer_config.deepstack_visual_indexes}") if self.add_encoder: self.vision_model = Qwen3VLVisionModel( @@ -201,10 +199,8 @@ def set_input_tensor(self, input_tensor: List[Dict[str, torch.Tensor]]): """Set input tensor for pipeline parallelism. """ if input_tensor is None or len(input_tensor) == 0 or input_tensor[0] is None: - print(f"for debug, rank {torch.distributed.get_rank()} in Qwen3VLModel.set_input_tensor(), input_tensor is None or len(input_tensor) == 0 or input_tensor[0] is None") return if self.config.use_dist_train: - print(f"for debug, rank {torch.distributed.get_rank()} in Qwen3VLModel.set_input_tensor(), input_tensor: {input_tensor}") assert isinstance(input_tensor, list), "Input tensor must be a list, but got {type(input_tensor)}" assert len(input_tensor) == 1, "Input tensor must be a list of length 1, but got {len(input_tensor)}" assert isinstance(input_tensor[0], dict), "Input tensor[0] must be a dictionary, but got {type(input_tensor[0])}" @@ -358,7 +354,6 @@ def forward( vision_grid_thw = collapse_thw(vision_grid_thw) if vision_data.shape[0] > 0: if self.vision_model is not None: - print(f"for debug, rank {torch.distributed.get_rank()} in Qwen3VLModel.forward(), before vision_model.forward()") vision_embeds, deepstack_feature_lists = self.vision_model( hidden_states=vision_data, grid_thw=vision_grid_thw, @@ -367,11 +362,9 @@ def forward( vision_module_output.append(vision_embeds) vision_module_output_tensor = torch.cat(vision_module_output, dim=0) output_vision_module = {'vision_module': vision_module_output_tensor} - print(f"for debug, rank {torch.distributed.get_rank()} in Qwen3VLModel.forward(), after vision_model.forward(), vision_embeds.shape: {vision_embeds.shape}") else: vision_embeds = self.vision_embeds deepstack_feature_lists = self.deepstack_feature_lists - print(f"for debug, rank {torch.distributed.get_rank()} in Qwen3VLModel.forward(), 
vision_model is None, vision_embeds.shape: {vision_embeds.shape}") else: vision_embeds = torch.zeros( (0, self.language_model.config.hidden_size), From 4412bcb8fabbf9a58ecbc945196f8fac344c585c Mon Sep 17 00:00:00 2001 From: shifangx Date: Sun, 15 Feb 2026 22:21:39 -0800 Subject: [PATCH 6/7] dist_train_vision_chunk_size --- .../bridge/models/qwen_vl/modelling_qwen3_vl/model.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py index f25337d92c..9d85c2954b 100644 --- a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py +++ b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py @@ -354,6 +354,14 @@ def forward( vision_grid_thw = collapse_thw(vision_grid_thw) if vision_data.shape[0] > 0: if self.vision_model is not None: + if self.config.use_dist_train: + assert cp_size == 1, "currently, dist train does not support context parallelism for encoder" + num_chunks = self.config.dist_train_vision_chunk_size + chunk_idx = self.pg_collection.dp.rank() % num_chunks + vision_data_chunks = torch.chunk(vision_data, chunks=num_chunks, dim=0) + vision_data = vision_data_chunks[chunk_idx] + vision_grid_thw_chunks = torch.chunk(vision_grid_thw, chunks=num_chunks, dim=0) + vision_grid_thw = vision_grid_thw_chunks[chunk_idx] vision_embeds, deepstack_feature_lists = self.vision_model( hidden_states=vision_data, grid_thw=vision_grid_thw, From fbe1dec2db023854deacd3ad6cea727493bfbc18 Mon Sep 17 00:00:00 2001 From: shifangx Date: Sun, 15 Feb 2026 22:36:43 -0800 Subject: [PATCH 7/7] support vision model type vit_2b --- src/megatron/bridge/models/gpt_provider.py | 5 +++++ .../qwen_vl/modelling_qwen3_vl/transformer_config.py | 11 +++++++++++ 2 files changed, 16 insertions(+) diff --git a/src/megatron/bridge/models/gpt_provider.py b/src/megatron/bridge/models/gpt_provider.py index e14ffd1c45..0e066403ed 100644 --- 
a/src/megatron/bridge/models/gpt_provider.py +++ b/src/megatron/bridge/models/gpt_provider.py @@ -198,6 +198,11 @@ class GPTModelProvider(TransformerConfig, ModelProviderMixin[MCoreGPTModel]): _pg_collection: Optional[ProcessGroupCollection] = None + # vision model type will be used to override the vision model config. + # if None, the vision model config will be used as is. + # currently, only vit_2b is supported. + vision_model_type: Optional[str] = None + # parameters for DistTrain use_dist_train: bool = False dist_train_vision_chunk_size: Optional[int] = 1 diff --git a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/transformer_config.py b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/transformer_config.py index 2e336b7b6a..d730e5d2c4 100644 --- a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/transformer_config.py +++ b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/transformer_config.py @@ -67,6 +67,17 @@ def get_vision_model_config(hf_config, megatron_config=None): add_bias_linear=True, add_qkv_bias=True, ) + if megatron_config.vision_model_type == "vit_2b": + hf_config.depth = 45 + hf_config.hidden_size = 1536 + hf_config.num_heads = 16 + hf_config.intermediate_size = 8960 + hf_config.patch_size = 16 + hf_config.spatial_merge_size = 2 + if hasattr(hf_config, "head_dim"): + hf_config.head_dim = 96 + else: + assert megatron_config.vision_model_type is None, ValueError(f"support only vit_2b, but got {config.vision_model_type}") # apply text model config to vision model config config.recompute_granularity = megatron_config.recompute_granularity