diff --git a/3rdparty/Megatron-LM b/3rdparty/Megatron-LM index bbbedbb9f5..844848b637 160000 --- a/3rdparty/Megatron-LM +++ b/3rdparty/Megatron-LM @@ -1 +1 @@ -Subproject commit bbbedbb9f53343762e4dc70abc771b813a83d817 +Subproject commit 844848b637ac5ed386875829226457f0029e06d4 diff --git a/examples/recipes/decentralized_pg/pretrain_qwen3_vl_simple.py b/examples/recipes/decentralized_pg/pretrain_qwen3_vl_simple.py new file mode 100644 index 0000000000..bb6c9d9c71 --- /dev/null +++ b/examples/recipes/decentralized_pg/pretrain_qwen3_vl_simple.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +============================================================================== +Example: Qwen3_VL Pretraining with Decentralized Process Groups (Simple) +============================================================================== + +This example demonstrates the simplest way to enable decentralized process groups: +just use an existing recipe and set `cfg.dist.use_decentralized_pg = True`. + +The setup() function inside pretrain() will automatically create the +ProcessGroupCollection using HyperCommGrid based on the parallelism settings. + +How to Run +---------- +# 8 GPUs: EP8 +uv run python -m torch.distributed.run --nproc_per_node=8 examples/recipes/decentralized_pg/pretrain_qwen3_vl_simple.py +""" + +import torch + +from megatron.bridge.recipes.qwen_vl.qwen3_vl import qwen3_vl_30b_a3b_pretrain_config +from megatron.bridge.training.pretrain import pretrain +from megatron.bridge.training.vlm_step import forward_step + + +def main() -> None: + """Run Qwen3-VL pretraining with decentralized process groups enabled.""" + # Get the standard Qwen3-VL 30B-A3B pretrain config with overrides + cfg = qwen3_vl_30b_a3b_pretrain_config( + # Use mock data for demo + mock=True, + # Parallelism + expert_model_parallel_size=8, + # Training settings (small for demo) + train_iters=100, + seq_length=1024, + global_batch_size=32, + micro_batch_size=1, + # LR schedule (must fit within train_iters) + lr_warmup_iters=10, + lr_decay_iters=100, + ) + # known issue with share_embeddings_and_output_weights + cfg.model.share_embeddings_and_output_weights = False + + # ========================================================================= + # KEY: Enable decentralized process groups + # ========================================================================= + cfg.dist.use_decentralized_pg = True + cfg.dist.use_gloo_process_groups = False # Gloo not supported with decentralized PG + + pretrain(config=cfg, forward_step_func=forward_step) + + # Cleanup + if torch.distributed.is_initialized(): + torch.distributed.barrier() + torch.distributed.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/examples/recipes/qwen_vl/conf/qwen3_vl_pretrain_override_example.yaml b/examples/recipes/qwen_vl/conf/qwen3_vl_pretrain_override_example.yaml index 5977114679..4318a9d40d 100644 ---
a/examples/recipes/qwen_vl/conf/qwen3_vl_pretrain_override_example.yaml +++ b/examples/recipes/qwen_vl/conf/qwen3_vl_pretrain_override_example.yaml @@ -16,12 +16,17 @@ model: seq_length: 4096 + use_hf_vision_model: false + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + context_parallel_size: 1 + cross_entropy_loss_fusion: false train: - train_iters: 20 - global_batch_size: 8 - micro_batch_size: 1 - eval_iters: 5 + train_iters: 1000 + global_batch_size: 16 + micro_batch_size: 2 + eval_iters: 100 optimizer: lr: 0.00025 @@ -40,6 +45,8 @@ dist: logger: log_interval: 1 + log_throughput: true + log_params_norm: true dataset: sequence_length: 4096 @@ -50,4 +57,14 @@ rng: ddp: grad_reduce_in_fp32: true +profiling: + memory_snapshot_path: snapshot.pickle + nvtx_ranges: false + profile_ranks: [0] + profile_step_end: 12 + profile_step_start: 10 + record_memory_history: false + record_shapes: false + use_nsys_profiler: false + use_pytorch_profiler: false diff --git a/examples/recipes/qwen_vl/finetune_qwen_vl.py b/examples/recipes/qwen_vl/finetune_qwen_vl.py index 73cc3517fc..6762c6a212 100644 --- a/examples/recipes/qwen_vl/finetune_qwen_vl.py +++ b/examples/recipes/qwen_vl/finetune_qwen_vl.py @@ -103,9 +103,9 @@ create_omegaconf_dict_config, parse_hydra_overrides, ) -from megatron.bridge.training.vlm_step import forward_step +from megatron.bridge.training.qwen3vl_step import forward_step from megatron.bridge.utils.common_utils import get_rank_safe - +from functools import partial logger: logging.Logger = logging.getLogger(__name__) @@ -185,6 +185,7 @@ def parse_cli_args() -> Tuple[argparse.Namespace, list[str]]: help="Use preloaded dataset provider (enabled automatically when --data-path is set).", ) parser.add_argument("--debug", action="store_true", help="Enable debug logging") + args, cli_dotlist_overrides = parser.parse_known_args() return args, cli_dotlist_overrides diff --git a/scripts/performance/argument_parser.py b/scripts/performance/argument_parser.py index 9773fb615d..99aa970b5a 100644 --- a/scripts/performance/argument_parser.py +++ b/scripts/performance/argument_parser.py @@ -145,7 +145,7 @@ def parse_cli_args(): parser.add_argument( "--domain", type=lower_str, - choices=["llm", "vlm"], + choices=["llm", "vlm", "qwen3vl"], help="Domain to use for experiment.", default="llm", ) diff --git a/scripts/performance/run_script.py b/scripts/performance/run_script.py index d70dd1e70b..86d9afd7a6 100644 --- a/scripts/performance/run_script.py +++ b/scripts/performance/run_script.py @@ -22,6 +22,7 @@ from megatron.bridge.training.gpt_step import forward_step from megatron.bridge.training.pretrain import pretrain from megatron.bridge.training.vlm_step import forward_step as vlm_forward_step +from megatron.bridge.training.qwen3vl_step import forward_step as qwen3vl_forward_step logger = logging.getLogger(__name__) @@ -61,6 +62,8 @@ def main(): # Select forward step function based on the model family name. 
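+ # `--domain qwen3vl` (added to the argument parser above) routes to the Qwen3-VL-specific step; `vlm` keeps the generic VLM step.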
if args.domain == "vlm": forward_step_func = vlm_forward_step + elif args.domain == "qwen3vl": + forward_step_func = qwen3vl_forward_step else: forward_step_func = forward_step diff --git a/scripts/performance/utils/overrides.py b/scripts/performance/utils/overrides.py index 4d29293c93..393c24a595 100644 --- a/scripts/performance/utils/overrides.py +++ b/scripts/performance/utils/overrides.py @@ -127,6 +127,46 @@ def _set_cuda_graph_overrides( return recipe +def _set_vision_cuda_graph_overrides( + recipe: ConfigContainer, + vision_cuda_graph_impl: Optional[str] = None, + vision_cuda_graph_scope: Optional[str | List[str]] = None, +) -> ConfigContainer: + """Set the vision encoder CUDA graph overrides. + + This enables TE CUDA graph for the vision encoder separately from the language model. + + Args: + recipe: The config container + vision_cuda_graph_impl: Vision encoder CUDA graph implementation ("none" or "transformer_engine") + vision_cuda_graph_scope: Vision encoder CUDA graph scope (e.g., ["attn"]) + + Returns: + Updated config container + """ + if isinstance(vision_cuda_graph_scope, str): + vision_cuda_graph_scope = [vision_cuda_graph_scope] + + if vision_cuda_graph_impl is not None: + recipe.model.vision_cuda_graph_impl = vision_cuda_graph_impl + + if vision_cuda_graph_impl == "transformer_engine": + # Ensure TE RNG tracker is enabled for CUDA graph compatibility + recipe.rng.te_rng_tracker = recipe.model.use_te_rng_tracker = True + + valid_te_scopes = ["attn", "mlp"] # Vision encoder typically only has attn and mlp + if vision_cuda_graph_scope: + assert all(scope in valid_te_scopes for scope in vision_cuda_graph_scope), ( + f"Invalid vision cuda graph scope: {vision_cuda_graph_scope}. " + f"Valid options for vision encoder are: {valid_te_scopes}" + ) + + if vision_cuda_graph_scope is not None: + recipe.model.vision_cuda_graph_scope = vision_cuda_graph_scope + + return recipe + + def _set_recompute_overrides( recipe: ConfigContainer, cpu_offloading_num_layers: Optional[int] = None, diff --git a/src/megatron/bridge/data/vlm_datasets/hf_provider.py b/src/megatron/bridge/data/vlm_datasets/hf_provider.py index c63c95e780..8014276753 100644 --- a/src/megatron/bridge/data/vlm_datasets/hf_provider.py +++ b/src/megatron/bridge/data/vlm_datasets/hf_provider.py @@ -67,6 +67,9 @@ class HFDatasetConversationProvider(DatasetProvider): # DataloaderConfig fields are inherited (num_workers, dataloader_type, etc.) dataloader_type: Optional[Literal["single", "cyclic", "external"]] = "single" + # Enable batch-level online sequence packing (dataset-level packing is available in FinetuneDatasetProvider) + pack_sequences_in_batch: bool = False + def _get_maker(self) -> Callable[..., List[Dict[str, Any]]]: registry: Dict[str, Callable[..., List[Dict[str, Any]]]] = { "make_rdr_dataset": make_rdr_dataset, diff --git a/src/megatron/bridge/models/conversion/param_mapping.py b/src/megatron/bridge/models/conversion/param_mapping.py index 1684828dcc..5ac282df88 100644 --- a/src/megatron/bridge/models/conversion/param_mapping.py +++ b/src/megatron/bridge/models/conversion/param_mapping.py @@ -1840,7 +1840,7 @@ class ConcatenatedQKVMapping(MegatronParamMapping[Dict[str, torch.Tensor]]): .. 
code-block:: python # Create mapping for attention weights - mapping = QKVMapping( + mapping = ConcatenatedQKVMapping( megatron_param="decoder.layers.*.self_attention.linear_qkv.weight", qkv="model.layers.*.self_attn.qkv.weight", ) diff --git a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/attention.py b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/attention.py new file mode 100644 index 0000000000..f7855c5692 --- /dev/null +++ b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/attention.py @@ -0,0 +1,260 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024 Alibaba PAI Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from megatron.core.transformer.attention import * + +from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.rope import apply_rotary_pos_emb_absolute + + +class Qwen3VLSelfAttention(SelfAttention): + """ + Overrides the SelfAttention class; the only difference is that Qwen3-VL uses apply_rotary_pos_emb_absolute + instead of apply_rotary_pos_emb. + """ + + def forward( + self, + hidden_states: Tensor, + attention_mask: Tensor, + key_value_states: Optional[Tensor] = None, + inference_context: Optional[BaseInferenceContext] = None, + rotary_pos_emb: Optional[Union[Tensor, Tuple[Tensor, Tensor]]] = None, + rotary_pos_cos: Optional[Tensor] = None, + rotary_pos_sin: Optional[Tensor] = None, + attention_bias: Optional[Tensor] = None, + packed_seq_params: Optional[PackedSeqParams] = None, + sequence_len_offset: Optional[int] = None, + *, + inference_params: Optional[BaseInferenceContext] = None, + rotary_pos_cos_sin: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor]: + """ + Perform a forward pass through the attention module. + + Args: + hidden_states (Tensor): Hidden states. + attention_mask (Tensor): Attention mask. + key_value_states (Optional[Tensor]): Key/value states (for cross attention). + inference_context (Optional[BaseInferenceContext]): Inference context that manages + KV cache. + rotary_pos_emb (Optional[Union[Tensor, Tuple[Tensor, Tensor]]]): Rotary + embedding tensor(s). + rotary_pos_cos (Optional[Tensor]): Rotary embedding cosine. + rotary_pos_sin (Optional[Tensor]): Rotary embedding sine. + attention_bias (Optional[Tensor]): Attention bias. + packed_seq_params (Optional[PackedSeqParams]): Parameters used for THD format. + sequence_len_offset (Optional[int]): Sequence length offset used for + inference CUDA graphs. + + Return: + (Tuple[Tensor, Tensor]) Attention output and bias. + + """ + # Check if we need to skip RoPE + # no_rope is 0-indexed array and self.layer_number is 1-indexed + no_rope = self.config.no_rope_freq[self.layer_number - 1] if self.config.no_rope_freq else False + if no_rope: + rotary_pos_emb = None + + inference_context = deprecate_inference_params(inference_context, inference_params) + + if inference_context and inference_context.is_dynamic_batching(): + assert HAVE_FA3 or is_fa_min_version("2.7.3"), ( + "flash attn version 2.7.3 or above is required for dynamic batching."
+ ) + + # hidden_states: [sq, b, h] + if self.config.flash_decode and not self.training and inference_context is not None: + rotary_pos_emb = None + else: + assert rotary_pos_cos is None and rotary_pos_sin is None + + # For self attention we just duplicate the rotary_pos_emb if it isn't already + if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple): + rotary_pos_emb = (rotary_pos_emb,) * 2 + + # ===================== + # Query, Key, and Value + # ===================== + # Get the query, key and value tensors based on the type of attention - + # self or cross attn. + nvtx_range_push(suffix="qkv") + query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states) + nvtx_range_pop(suffix="qkv") + + # =================================================== + # Adjust key, value, and rotary_pos_emb for inference + # =================================================== + + in_decode_mode = inference_context is not None and inference_context.is_decode_only() and not self.training + + # This branch only runs in the decode phase of flash decoding and returns after the linear + # projection. This conditional is not used in the prefill phase or non-flash-decoding cases. + nvtx_range_push(suffix="adjust_key_value") + if in_decode_mode and self.config.flash_decode: + assert self.layer_number in inference_context.key_value_memory_dict + assert inference_context.sequence_len_offset is not None + inference_key_memory, inference_value_memory = inference_context.key_value_memory_dict[self.layer_number] + output = self.flash_decode( + sequence_len_offset=sequence_len_offset, + query_layer=query, + key_layer=key, + value_layer=value, + inference_key_memory=inference_key_memory, + inference_value_memory=inference_value_memory, + rotary_cos=rotary_pos_cos, + rotary_sin=rotary_pos_sin, + rotary_interleaved=self.config.rotary_interleaved, + ) + out = output.transpose(0, 1).contiguous() + context_layer = out.view(out.size(0), out.size(1), -1) + output, bias = self.linear_proj(context_layer) + return output, bias + + if in_decode_mode and self.config.enable_cuda_graph and inference_context.is_static_batching(): + raise ValueError(f"CUDA graphs must use flash decode with static batching!") + + query, key, value, rotary_pos_emb, attn_mask_type, block_table = self._adjust_key_value_for_inference( + inference_context, + query, + key, + value, + rotary_pos_emb, + rotary_pos_cos, + rotary_pos_sin, + sequence_len_offset, + ) + + if packed_seq_params is not None: + query = query.squeeze(1) + key = key.squeeze(1) + value = value.squeeze(1) + nvtx_range_pop(suffix="adjust_key_value") + + # ================================================ + # relative positional embedding (rotary embedding) + # ================================================ + nvtx_range_push(suffix="rotary_pos_emb") + if rotary_pos_emb is not None and not self.config.flash_decode: + q_pos_emb, k_pos_emb = rotary_pos_emb + + if packed_seq_params is not None: + if packed_seq_params.cu_seqlens_q_padded is not None: + cu_seqlens_q = packed_seq_params.cu_seqlens_q_padded + else: + cu_seqlens_q = packed_seq_params.cu_seqlens_q + if packed_seq_params.cu_seqlens_kv_padded is not None: + cu_seqlens_kv = packed_seq_params.cu_seqlens_kv_padded + else: + cu_seqlens_kv = packed_seq_params.cu_seqlens_kv + else: + cu_seqlens_q = cu_seqlens_kv = None + + if q_pos_emb is not None: + # TODO VIJAY: simplify + if inference_context is None or inference_context.is_static_batching(): + query = apply_rotary_pos_emb_absolute( + query, + q_pos_emb, 
+ config=self.config, + cu_seqlens=cu_seqlens_q, + ) + else: + query = inference_context.apply_rotary_emb_query( + query, + q_pos_emb, + self.config, + cu_seqlens_q, + self.model_comm_pgs.cp, + ) + if k_pos_emb is not None: + key = apply_rotary_pos_emb_absolute( + key, + k_pos_emb, + config=self.config, + cu_seqlens=cu_seqlens_kv, + ) + + # TODO, can apply positional embedding to value_layer so it has + # absolute positional embedding. + # otherwise, only relative positional embedding takes effect + # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb) + nvtx_range_pop(suffix="rotary_pos_emb") + + # ================================== + # core attention computation + # ================================== + + nvtx_range_push(suffix="core_attention") + if self.checkpoint_core_attention and self.training: + core_attn_out = self._checkpointed_attention_forward( + query, + key, + value, + attention_mask, + attn_mask_type=attn_mask_type, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + ) + else: + if inference_context is None or inference_context.is_static_batching(): + # Static batching attention kernel. + core_attn_out = self.core_attention( + query, + key, + value, + attention_mask, + attn_mask_type=attn_mask_type, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + ) + + else: + # Dynamic batching attention kernel. + q, k, v = (query, key, value) + cu_query_lengths, max_seqlen_q = inference_context.cu_query_lengths() + cu_kv_lengths, kv_lengths, kv_lengths_decode_only, max_seqlen_k = inference_context.cu_kv_lengths() + + core_attn_out = self.flash_decode_and_prefill( + q, + k, + v, + max_seqlen_q, + max_seqlen_k, + cu_query_lengths, + cu_kv_lengths, + kv_lengths, + kv_lengths_decode_only, + block_table, + ) + core_attn_out = rearrange(core_attn_out, "s b h d -> s b (h d)") + + if packed_seq_params is not None and packed_seq_params.qkv_format == "thd": + # reshape to same output shape as unpacked case + # (t, np, hn) -> (t, b=1, h=np*hn) + # t is the pack size = sum (sq_i) + # note that batch is a dummy dimension in the packed case + core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1) + nvtx_range_pop(suffix="core_attention") + + # ================= + # Output. [sq, b, h] + # ================= + + nvtx_range_push(suffix="linear_proj") + output, bias = self.linear_proj(core_attn_out) + nvtx_range_pop(suffix="linear_proj") + + return output, bias diff --git a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py index a99feea993..17549b3f0f 100644 --- a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py +++ b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py @@ -13,17 +13,30 @@ # limitations under the License. 
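The subclass above only takes effect once it is wired into a layer spec; `Qwen3VLModel.__init__` below does this by mutating `vision_transformer_layer_spec` in place. A minimal sketch of the spec-swap pattern, assuming a Megatron-Core style `ModuleSpec`; `swap_self_attention` is an illustrative helper, not part of the codebase:

```python
from megatron.core.transformer.spec_utils import ModuleSpec


def swap_self_attention(layer_spec: ModuleSpec, attention_cls: type) -> ModuleSpec:
    """Point a layer spec's self-attention at a custom class (illustrative helper)."""
    # `submodules.self_attention` is itself a ModuleSpec; replacing its `module`
    # changes which class build_module() instantiates, while keeping the
    # already-configured submodules (linear_qkv, core_attention, linear_proj, ...).
    layer_spec.submodules.self_attention.module = attention_cls
    return layer_spec


# Usage, mirroring the assignment in Qwen3VLModel.__init__:
# vision_transformer_layer_spec = swap_self_attention(vision_transformer_layer_spec, Qwen3VLSelfAttention)
```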
import torch +from copy import copy from megatron.core import InferenceParams, mpu, tensor_parallel from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer import MegatronModule from megatron.core.transformer.spec_utils import ModuleSpec -from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig as Qwen3VLConfigHF -from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionModel as Qwen3VLVisionModelHF +from megatron.core.utils import nvtx_range_pop, nvtx_range_push from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.text_model import Qwen3VLGPTModel +from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig as Qwen3VLConfigHF from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.transformer_config import Qwen3VLTransformerConfig -from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.utils import get_rope_index, split_deepstack_embs -from megatron.bridge.utils.common_utils import hook_hf_module_setattr_for_tp_grad_sync +from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.utils import ( + split_deepstack_embs, + reorganize_inputs, + qwen3vl_cp_split, + split_data_cp_rank, + AllGatherVisionEmbeddings, + collapse_thw, + get_vision_cp_data, +) +from megatron.bridge.training.utils.packed_seq_utils import preprocess_packed_seqs +from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.rope import get_rope_index +from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.attention import Qwen3VLSelfAttention +from megatron.bridge.training.utils.pg_utils import get_pg_collection class Qwen3VLModel(MegatronModule): @@ -54,14 +67,19 @@ def __init__( language_transformer_config: Qwen3VLTransformerConfig, language_transformer_layer_spec: ModuleSpec, vision_transformer_config: Qwen3VLConfigHF, + vision_transformer_layer_spec: ModuleSpec, + vision_patch_merger_spec: ModuleSpec, parallel_output: bool = True, pre_process: bool = True, post_process: bool = True, add_encoder: bool = True, add_decoder: bool = True, + pg_collection: ProcessGroupCollection = None, ) -> None: super().__init__(config=language_transformer_config) + vision_transformer_layer_spec.submodules.self_attention.module = Qwen3VLSelfAttention + self.pre_process = pre_process self.post_process = post_process self.add_encoder = add_encoder @@ -74,18 +92,61 @@ self.video_token_id = language_transformer_config.video_token_id self.vision_start_token_id = language_transformer_config.vision_start_token_id + self.square_merge_size = vision_transformer_config.spatial_merge_size**2 + # This attribute is needed to check if an all-reduce is required # on the word embeddings inside `finalize_model_grads._allreduce_word_embedding_grads`. self.share_embeddings_and_output_weights = False + + # process groups + self.pg_collection = pg_collection + self.cp_group = pg_collection.cp + self.tp_group = pg_collection.tp + self.pp_group = pg_collection.pp + assert hasattr(self.pg_collection, "embd"), ( + "pg_collection must have an embd group. Previous versions used the default " + "`parallel_state.default_embedding_ranks` to create this process group. " + "If you are using the default process groups, use " + "`parallel_state.get_embedding_group()`. " + "If you don't need embd_group, explicitly set it to None."
+ ) + self.embd_group = pg_collection.embd + self.vp_stage = None + self.vp_size = self.config.virtual_pipeline_model_parallel_size if self.pre_process: - # Initialize vision model with random weights from config - self.vision_model = Qwen3VLVisionModelHF._from_config(vision_transformer_config) - # Ensure HF visual tower params are marked for TP grad sync and future assignments are hooked. - hook_hf_module_setattr_for_tp_grad_sync(self.vision_model) - # Move to device if available - if torch.cuda.is_available(): - self.vision_model = self.vision_model.to("cuda") + if not language_transformer_config.use_hf_vision_model: + # use megatron vision model + from .vision_model import Qwen3VLVisionModel + from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.transformer_config import ( + get_vision_model_config, + ) + + megatron_vision_transformer_config = get_vision_model_config(copy(language_transformer_config), vision_transformer_config) + megatron_vision_transformer_config.pipeline_model_parallel_size = 1 + megatron_vision_transformer_config.first_pipeline_num_layers = None + + self.vision_model = Qwen3VLVisionModel( + megatron_vision_transformer_config, + vision_transformer_layer_spec, + vision_patch_merger_spec, + pre_process=True, + post_process=True, + pg_collection=pg_collection, + ) + print(f"rank {torch.distributed.get_rank()} use megatron vision model") + else: + # use hf vision model + from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionModel as Qwen3VLVisionModelHF + from megatron.bridge.utils.common_utils import hook_hf_module_setattr_for_tp_grad_sync + + # Initialize vision model with random weights from config + self.vision_model = Qwen3VLVisionModelHF._from_config(vision_transformer_config) + # Ensure HF visual tower params are marked for TP grad sync and future assignments are hooked. + hook_hf_module_setattr_for_tp_grad_sync(self.vision_model) + # Move to device if available + if torch.cuda.is_available(): + self.vision_model = self.vision_model.to("cuda") + print(f"rank {torch.distributed.get_rank()} use hf vision model") self.language_model = Qwen3VLGPTModel( config=language_transformer_config, @@ -101,6 +162,7 @@ def __init__( fp16_lm_cross_entropy=language_transformer_config.fp16_lm_cross_entropy, share_embeddings_and_output_weights=language_transformer_config.share_embeddings_and_output_weights, scatter_embedding_sequence_parallel=False, + pg_collection=pg_collection, ) if pre_process: assert len(vision_transformer_config.deepstack_visual_indexes) <= len( @@ -112,6 +174,36 @@ def __init__( ) self.share_embeddings_and_output_weights = self.language_model.share_embeddings_and_output_weights + self.pg_collection = get_pg_collection(self) + # Expose position_embedding_type, rotary_pos_emb, and decoder for CUDA graph helper compatibility + # The CUDA graph helper expects model.position_embedding_type, model.rotary_pos_emb, and model.decoder, + # but in Qwen3VL these are nested under language_model. This provides direct access. 
+ # Expose these attributes for CUDA graph helper compatibility only when CUDA graph is enabled + cuda_graph_enabled = getattr(self.language_model.config, "cuda_graph_impl", "none") != "none" + if cuda_graph_enabled: + self.position_embedding_type = self.language_model.position_embedding_type + self.rotary_pos_emb = self.language_model.rotary_pos_emb + self.decoder = self.language_model.decoder def shared_embedding_or_output_weight(self): """This is a convenience method to surface the language model's word embeddings, which is @@ -140,39 +232,59 @@ ): """Freeze model modules. - Make specific modules non-trainable by setting requires_grad to False. + Make specific modules non-trainable by setting requires_grad to False for the module's parameters. Args: freeze_language_model (bool): Freeze the language model module. - freeze_vision_model (bool): Freeze the vision model module (patch_embed, blocks, pos_embed). - freeze_vision_projection (bool): Freeze the vision projection modules (merger and deepstack_merger_list). + freeze_vision_model (bool): Freeze the vision model module. + freeze_vision_projection (bool): Freeze the vision projection module.
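+ Note: with the Megatron vision tower (use_hf_vision_model=False), freezing the vision model while leaving the projection trainable re-enables requires_grad on the merger parameters after the blanket freeze (see the branch below).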
""" modules = [] - - if freeze_language_model and self.language_model is not None: - modules.append(self.language_model) - - if freeze_vision_model and self.vision_model is not None: - # Freeze vision encoder components (patch_embed, blocks, pos_embed, rotary_pos_emb) - if hasattr(self.vision_model, "patch_embed"): - modules.append(self.vision_model.patch_embed) - if hasattr(self.vision_model, "blocks"): - modules.append(self.vision_model.blocks) - if hasattr(self.vision_model, "pos_embed"): - modules.append(self.vision_model.pos_embed) - if hasattr(self.vision_model, "rotary_pos_emb"): - modules.append(self.vision_model.rotary_pos_emb) - - if freeze_vision_projection and self.vision_model is not None: - # Freeze vision projection components (merger and deepstack_merger_list) - if hasattr(self.vision_model, "merger"): + if not self.config.use_hf_vision_model: + if freeze_language_model and self.language_model is not None: + modules.append(self.language_model) + if freeze_vision_model and self.vision_model is not None: + modules.append(self.vision_model) + if freeze_vision_projection and self.vision_model is not None: + modules.append(self.vision_model.decoder.deepstack_merger_list) modules.append(self.vision_model.merger) - if hasattr(self.vision_model, "deepstack_merger_list"): - modules.append(self.vision_model.deepstack_merger_list) - for module in modules: - for param in module.parameters(): - param.requires_grad = False + for module in modules: + for param in module.parameters(): + param.requires_grad = False + + if freeze_vision_model and not freeze_vision_projection: + if self.vision_model is not None: + for param in self.vision_model.decoder.deepstack_merger_list.parameters(): + param.requires_grad = True + for param in self.vision_model.merger.parameters(): + param.requires_grad = True + else: + modules = [] + if freeze_language_model and self.language_model is not None: + modules.append(self.language_model) + + if freeze_vision_model and self.vision_model is not None: + # Freeze vision encoder components (patch_embed, blocks, pos_embed, rotary_pos_emb) + if hasattr(self.vision_model, "patch_embed"): + modules.append(self.vision_model.patch_embed) + if hasattr(self.vision_model, "blocks"): + modules.append(self.vision_model.blocks) + if hasattr(self.vision_model, "pos_embed"): + modules.append(self.vision_model.pos_embed) + if hasattr(self.vision_model, "rotary_pos_emb"): + modules.append(self.vision_model.rotary_pos_emb) + + if freeze_vision_projection and self.vision_model is not None: + # Freeze vision projection components (merger and deepstack_merger_list) + if hasattr(self.vision_model, "merger"): + modules.append(self.vision_model.merger) + if hasattr(self.vision_model, "deepstack_merger_list"): + modules.append(self.vision_model.deepstack_merger_list) + + for module in modules: + for param in module.parameters(): + param.requires_grad = False def forward( self, @@ -187,10 +299,20 @@ def forward( pixel_values_videos: torch.Tensor = None, image_grid_thw: torch.Tensor = None, video_grid_thw: torch.Tensor = None, - # cat set at dataset + # can set at dataset image_input_mask: torch.Tensor = None, + video_input_mask: torch.Tensor = None, + cp_img_num: list[int] = None, + images_padded: list[bool] = None, + **kwargs, ) -> torch.Tensor: """Forward function of the Qwen3VL model. 
+ ### Workaround for supporting sequence packing with context parallelism: + # a CP split on packed sequences would make the model lose vision token information, so we keep + # the original input_ids and pack them only after the vision embeddings are computed, + # cooperating with verl's models/mcore/model_forward.py. + # The combined_embeddings are packed to THD here; whether packed_seq_params is None determines if packing is needed. + # This function needs position_ids and attention_mask in BSHD format, whether or not packed sequences are used. Args: image_data (torch.Tensor): input image of shape [total_thw_size, n_features]. @@ -201,43 +323,85 @@ labels (torch.Tensor): Optional target text labels [batch, combined_seq_len]. inference_params (InferenceParams): Inference-time parameters including KV cache. - video_start_index: - 0 -- all video - len(video_seq) -- all image - others -- mixture - *_input_mask: should not be None in the first PP stage Returns: output (torch.Tensor): Loss of shape [b, s] if labels are provided, otherwise logits of shape [b, s, vocab_size]. """ - assert pixel_values_videos is None and video_grid_thw is None, "not support video now" assert inference_params is None, "not support inference" - video_start_index = 0 vision_grid_thw = None vision_data = None - image_mask = None + vision_mask = None deepstack_feature_lists = None + # position ids are computed within the model position_ids = None - torch.cuda.nvtx.range_push("Qwen3VLModel.forward.pre_process") + nvtx_range_push(suffix="forward_pre_process") + cp_size = self.pg_collection.cp.size() if self.pre_process: - if image_grid_thw is not None: - image_mask = image_input_mask - if image_mask is None: - image_mask = (input_ids == self.image_token_id).contiguous() - vision_grid_thw = image_grid_thw - vision_data = pixel_values - video_start_index = image_mask.sum().item() - assert video_start_index > 0 + # reorganize_inputs can also be done in the dataset + vision_data, vision_grid_thw, vision_mask = reorganize_inputs( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + image_input_mask=image_input_mask, + video_input_mask=video_input_mask, + image_token_id=self.image_token_id, + video_token_id=self.video_token_id, + square_merge_size=self.square_merge_size, + ) vision_embeds = None if vision_grid_thw is not None and vision_grid_thw.shape[0] > 0: - vision_embeds, deepstack_feature_lists = self.vision_model( - hidden_states=vision_data, - grid_thw=vision_grid_thw, - ) + if cp_size > 1: + if cp_img_num is None: + assert images_padded is None + vision_data, vision_grid_thw, cp_img_num, images_padded = qwen3vl_cp_split( + cp_size, + vision_data, + vision_grid_thw, + ) + vision_data, vision_grid_thw, seqlen_on_cp_ranks = get_vision_cp_data( + vision_data, + vision_grid_thw, + self.square_merge_size, + cp_img_num, + images_padded, + ) + vision_grid_thw = collapse_thw(vision_grid_thw) + if vision_data.shape[0] > 0: + vision_embeds, deepstack_feature_lists = self.vision_model( + hidden_states=vision_data, + grid_thw=vision_grid_thw, + ) + else: + vision_embeds = torch.zeros( + (0, self.language_model.config.hidden_size), + device=vision_data.device, + dtype=torch.bfloat16, + ) + deepstack_feature_lists = [] + for _ in self.vision_model.config.deepstack_visual_indexes: + deepstack_feature_lists.append( + torch.zeros( + (0, self.language_model.config.hidden_size), + device=vision_data.device, +
dtype=torch.bfloat16, + ) + ) + if cp_size > 1: + vision_embeds = AllGatherVisionEmbeddings.apply( + vision_embeds, + seqlen_on_cp_ranks, + ) + for i in range(len(deepstack_feature_lists)): + deepstack_feature_lists[i] = AllGatherVisionEmbeddings.apply( + deepstack_feature_lists[i], + seqlen_on_cp_ranks, + ) combined_embeddings = self.language_model.embedding( input_ids=input_ids, @@ -245,33 +409,89 @@ ).clone() # [text_seq_len, b, h_language] if vision_embeds is not None: - if video_start_index == 0: - image_embeds = None - video_embeds = vision_embeds - elif video_start_index == vision_embeds.shape[0]: - image_embeds = vision_embeds - video_embeds = None - elif 0 < video_start_index < vision_embeds.shape[0]: - image_embeds = vision_embeds[:video_start_index] - video_embeds = vision_embeds[video_start_index:] - else: - raise ValueError( - f"Expect video token start index in range [0, {vision_embeds.shape[0]}], but got " - f"{video_start_index}" - ) - assert video_embeds is None, "not support video now" + combined_embeddings = combined_embeddings.transpose(0, 1).contiguous() + combined_embeddings[vision_mask] = vision_embeds + combined_embeddings = combined_embeddings.transpose(0, 1).contiguous() + + if combined_embeddings is not None and cp_size > 1 and packed_seq_params is None: + combined_embeddings = split_data_cp_rank(combined_embeddings, cp_size, 0) + if packed_seq_params is not None: + assert attention_mask is not None, ( + "attention_mask is required to compute position ids and to split across CP and SP" + ) + input_ids_thd, _ = preprocess_packed_seqs(input_ids, attention_mask, pre_process=True) + _, _, vision_mask_thd = reorganize_inputs( + input_ids=input_ids_thd, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + image_input_mask=image_input_mask, + video_input_mask=video_input_mask, + image_token_id=self.image_token_id, + video_token_id=self.video_token_id, + square_merge_size=self.square_merge_size, + ) + + if deepstack_feature_lists is not None: + tmp_embeddings = torch.zeros_like(combined_embeddings.transpose(0, 1)) + new_deepstack_feature_lists = [] + for deepstack_visual_embed in deepstack_feature_lists: + tmp_embeddings[vision_mask] = deepstack_visual_embed + tmp_embeddings_thd = preprocess_packed_seqs( + tmp_embeddings.contiguous(), + attention_mask, + pre_process=True, + )[0] + new_deepstack_feature_lists.append(tmp_embeddings_thd[vision_mask_thd].contiguous()) + + deepstack_feature_lists = new_deepstack_feature_lists + + vision_mask = vision_mask_thd + combined_embeddings_thd = ( + preprocess_packed_seqs( + combined_embeddings.transpose(0, 1).contiguous(), + attention_mask, + pre_process=True, + )[0] + .transpose(0, 1) + .contiguous() + ) + combined_embeddings = combined_embeddings_thd - if image_embeds is not None: - combined_embeddings = combined_embeddings.transpose(0, 1).contiguous() - combined_embeddings[image_mask] = image_embeds - combined_embeddings = combined_embeddings.transpose(0, 1).contiguous() if self.config.sequence_parallel: combined_embeddings = tensor_parallel.scatter_to_sequence_parallel_region(combined_embeddings) combined_embeddings = combined_embeddings.contiguous() + else: combined_embeddings = None + nvtx_range_pop(suffix="forward_pre_process") + + nvtx_range_push(suffix="forward_language_module") + visual_pos_masks = vision_mask + deepstack_visual_embeds = deepstack_feature_lists + if self.config.sequence_parallel or cp_size > 1: + if packed_seq_params
is None: # BSHD + visual_pos_masks, deepstack_visual_embeds = split_deepstack_embs( + visual_pos_masks, + deepstack_visual_embeds, + tp_size=self.pg_collection.tp.size(), + tp_rank=self.pg_collection.tp.rank(), + cp_size=cp_size, + cp_rank=self.pg_collection.cp.rank(), + ) + elif self.config.sequence_parallel: # THD and SP + visual_pos_masks, deepstack_visual_embeds = split_deepstack_embs( + visual_pos_masks, + deepstack_visual_embeds, + tp_size=self.pg_collection.tp.size(), + tp_rank=self.pg_collection.tp.rank(), + cp_size=1, + cp_rank=0, + ) if position_ids is None: + # BSHD position_ids, _ = get_rope_index( self.config.spatial_merge_size, self.image_token_id, @@ -281,20 +501,16 @@ image_grid_thw=image_grid_thw, video_grid_thw=video_grid_thw, attention_mask=attention_mask, - packed_seq_params=packed_seq_params, - ) - - visual_pos_masks = image_mask - deepstack_visual_embeds = deepstack_feature_lists - if self.config.sequence_parallel: - visual_pos_masks, deepstack_visual_embeds = split_deepstack_embs( - visual_pos_masks, - deepstack_visual_embeds, - tp_size=mpu.get_tensor_model_parallel_world_size(), - tp_rank=mpu.get_tensor_model_parallel_rank(), - ) - torch.cuda.nvtx.range_pop() - torch.cuda.nvtx.range_push("Qwen3VLModel.forward.language_model") + ) # [3, b, s] + if packed_seq_params is not None: + # convert position_ids to THD format + position_ids = ( + preprocess_packed_seqs(position_ids.permute(1, 2, 0), attention_mask, pre_process=True)[0] + .permute(2, 0, 1) + .contiguous() + ) + attention_mask = None + self.language_model.rotary_pos_emb.is_thd_format = True output = self.language_model( input_ids=None, position_ids=position_ids, # None in encoder @@ -306,7 +522,8 @@ visual_pos_masks=visual_pos_masks, deepstack_visual_embeds=deepstack_visual_embeds, **(extra_block_kwargs or {}), + **kwargs, ) - torch.cuda.nvtx.range_pop() + nvtx_range_pop(suffix="forward_language_module") return output diff --git a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/rope.py b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/rope.py index bc76415c74..10f75ec4bf 100644 --- a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/rope.py +++ b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/rope.py @@ -16,14 +16,74 @@ from typing import List, Optional import torch +import torch.nn as nn from megatron.core.packed_seq_params import PackedSeqParams from torch import Tensor -from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLTextRotaryEmbedding -from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextRotaryEmbedding +from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.transformer_config import Qwen3VLTransformerConfig +from megatron.core.models.common.embeddings.rope_utils import ( + _apply_rotary_pos_emb_bshd, + get_pos_emb_on_this_cp_rank, ) -class Qwen3VLMoETextRotaryEmbedding(Qwen3VLMoeTextRotaryEmbedding): - """Qwen3-VL MoE text rotary position embedding.""" + +class Qwen3VLMultimodalRotaryEmbedding(nn.Module): + """Multimodal rotary embedding for the language model. + Only supports Qwen3-VL. + + Args: + kv_channels (int): Projection weights dimension in multi-head attention. Obtained + from transformer config + rotary_percent (float): Percent of rotary dimension to use for rotary position + embeddings. + rotary_interleaved (bool, optional): If True, interleaved rotary position embeddings. + Defaults to False. + seq_len_interpolation_factor (float, optional): scale of linearly interpolating RoPE + for longer sequences.
The value must be a float larger than 1.0. Defaults to None + rotary_base (int, optional): Base period for rotary position embeddings. Defaults to + 10000. + """ + + def __init__( + self, + kv_channels: int, + rotary_percent: float, + rotary_interleaved: bool = False, + seq_len_interpolation_factor: Optional[float] = None, + rotary_base: int = 10000, + cp_group: torch.distributed.ProcessGroup = None, + ) -> None: + super().__init__() + + dim = kv_channels + if rotary_percent < 1.0: + dim = int(dim * rotary_percent) + self.rotary_interleaved = rotary_interleaved + assert not self.rotary_interleaved, "interleaved RoPE is not supported for Qwen3-VL" + + self.seq_len_interpolation_factor = seq_len_interpolation_factor + self.inv_freq = 1.0 / ( + rotary_base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=torch.cuda.current_device()) / dim) + ) + self.is_thd_format = False # in THD format the rotary_pos_emb does not need to be split along CP + self.cp_group = cp_group + + def apply_interleaved_mrope(self, freqs, mrope_section): + """Apply interleaved MRoPE to 3D rotary embeddings. + Reorganizes frequency layout from chunked [TTT...HHH...WWW] to + interleaved [THTHWHTHW...TT], preserving frequency continuity. + Args: + freqs: (3, bs, seq_len, head_dim // 2) + mrope_section: (3,) + Returns: + freqs_t: (bs, seq_len, head_dim // 2) + """ + freqs_t = freqs[0] # just overwrite the first dimension T + for dim, offset in enumerate((1, 2), start=1): # H, W + length = mrope_section[dim] * 3 + idx = slice(offset, length, 3) + freqs_t[..., idx] = freqs[dim, ..., idx] + return freqs_t def forward( self, @@ -38,56 +98,210 @@ position_ids (torch.Tensor): A position_id tensor with shape [3, batchsize, seqlens] mrope_section (list[int]): Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. - + packed_seq_params (PackedSeqParams, optional): Packed sequence params. Defaults to None. Returns: - Tensor: Raw frequency embeddings for Megatron Core (shape: [seq_length, bs, 1, dim]). - Megatron Core will compute cos/sin internally and apply attention_scaling. + Tensor: Embeddings after applying RoPE. """ # Use fp32 for position indices to avoid precision loss when inv_freq is bf16.
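+ # (bf16 keeps only 8 significand bits, so integer positions above 256 are no longer exact)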
seq = position_ids.to(device=self.inv_freq.device, dtype=torch.float32) - # if self.seq_len_interpolation_factor is not None: - # seq *= 1 / self.seq_len_interpolation_factor + if self.seq_len_interpolation_factor is not None: + seq *= 1 / self.seq_len_interpolation_factor # shape (3, bs, dim, 1) - inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, seq.shape[1], -1, 1) + inv_freq_expanded = self.inv_freq[None, None, :, None].expand(3, seq.shape[1], -1, 1) # shape (3, bs, 1, seq_length) seq_expanded = seq[:, :, None, :].float() # shape (3, bs, seq_length, dim) freqs = (inv_freq_expanded @ seq_expanded).transpose(2, 3) freqs = self.apply_interleaved_mrope(freqs, mrope_section) emb = torch.cat((freqs, freqs), dim=-1) + + # shape (seq_length, bs, 1, 2 * dim) emb = emb[..., None, :].transpose(0, 1).contiguous() - _ = packed_seq_params # packed sequences not supported yet + if self.cp_group.size() > 1 and not self.is_thd_format: + # slice rotary_pos_emb along sequence dimension and select the partition of the current + # CP rank + emb = get_pos_emb_on_this_cp_rank(emb, 0, self.cp_group) return emb -class Qwen3VLTextRotaryEmbedding(Qwen3VLTextRotaryEmbedding): - """Qwen3-VL text rotary position embedding for non-MoE models.""" +# Slightly modified from Qwen3VLModel.get_rope_index +def get_rope_index( + spatial_merge_size: int, + image_token_id: int, + video_token_id: int, + vision_start_token_id: int, + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + packed_seq_params: Optional[PackedSeqParams] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """Different from the original implementation, Qwen3-VL uses timestamps rather than absolute time position ids.""" - def forward( - self, - position_ids: torch.Tensor, - mrope_section: List[int] = None, - packed_seq_params: Optional[PackedSeqParams] = None, - **kwargs, - ) -> Tensor: - """Forward pass for non-MoE Qwen3-VL RoPE. + # Since we use timestamps to separate videos, the video_grid_thw should also be split + if video_grid_thw is not None: + video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0) + video_grid_thw[:, 0] = 1 - Args: - position_ids: Position IDs tensor - mrope_section: Optional mrope section (if not provided, uses self.mrope_section) - """ - if mrope_section is None: - mrope_section = self.mrope_section - - if position_ids.ndim == 2: - position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) - inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) - position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) - freqs = self.apply_interleaved_mrope(freqs, mrope_section) - emb = torch.cat((freqs, freqs), dim=-1) - emb = emb[..., None, :].transpose(0, 1).contiguous() - _ = packed_seq_params # packed sequences not supported yet - return emb + if packed_seq_params is not None and attention_mask is None and input_ids is not None: + # Build an attention mask from packed sequence metadata when one is not provided. + # cu_seqlens_q entries are cumulative lengths; their diffs give per-sample lengths.
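+ # e.g. cu_seqlens_q = [0, 5, 9, 12] -> per-sample lengths [5, 4, 3]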
+ cu_seqlens = packed_seq_params.cu_seqlens_q + if cu_seqlens is not None and cu_seqlens.numel() >= 2: + seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] + attention_mask = torch.zeros_like(input_ids, dtype=input_ids.dtype) + max_len = attention_mask.shape[1] + for i, seq_len in enumerate(seq_lens.tolist()): + valid = min(int(seq_len), max_len) + attention_mask[i, :valid] = 1 + else: + # Fallback to a dense mask if packed metadata is missing. + attention_mask = torch.ones_like(input_ids) + + mrope_position_deltas = [] + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + # t_index is always 0 because llm_grid_t is always 1 (we use timestamps to encode the temporal information for videos) + t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas 
= torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + +def apply_rotary_pos_emb_thd_absolute( + t: Tensor, cu_seqlens: Tensor, freqs: Tensor, rotary_interleaved: bool = False +) -> Tensor: + """A baseline implementation of applying RoPE for `thd` format. + + Args: + t (Tensor): Input tensor T is of shape [t, h, d] + cu_seqlens(Tensor): Cumulative sum of sequence lengths in a batch for `t`, + with shape [b + 1] and dtype torch.int32. + freqs (Tensor): Rotary Positional embedding tensor freq is of shape [max_s, 1, 1, d] + + Returns: + Tensor: Shape [t, h, d]. The input tensor after applying RoPE. + """ + return _apply_rotary_pos_emb_bshd(t[:, None], freqs, rotary_interleaved=rotary_interleaved).squeeze(1) + + +def apply_rotary_pos_emb_absolute( + t: Tensor, + freqs: Tensor, + config: Qwen3VLTransformerConfig, + cu_seqlens: Optional[Tensor] = None, +): + """ + Reroute to the appropriate apply_rotary_pos_emb function depending on + bshd (conventional) / thd (packed seq) format + + In Qwen3-VL, the shape of freqs is (seq_length, bs, 1, 2 * dim) instead of [max_seqlen, 1, 1, 2 * dim] + """ + assert not config.apply_rope_fusion + orig_t_dtype = t.dtype + if config.apply_rotary_pos_emb_in_fp32: + t = t.float() + + if cu_seqlens is None: + result = _apply_rotary_pos_emb_bshd(t, freqs, rotary_interleaved=config.rotary_interleaved) + else: + result = apply_rotary_pos_emb_thd_absolute(t, cu_seqlens, freqs, rotary_interleaved=config.rotary_interleaved) + + if config.apply_rotary_pos_emb_in_fp32: + result = result.to(orig_t_dtype) + + return result diff --git a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/text_model.py b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/text_model.py index c3811cdcac..a3700d7803 100644 --- a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/text_model.py +++ b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/text_model.py @@ -24,14 +24,12 @@ from megatron.core.inference.contexts import BaseInferenceContext from megatron.core.models.gpt.gpt_model import GPTModel from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.spec_utils import ModuleSpec from megatron.core.utils import deprecate_inference_params from torch import Tensor -from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.rope import ( - Qwen3VLMoETextRotaryEmbedding, - Qwen3VLTextRotaryEmbedding, -) +from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.rope import Qwen3VLMultimodalRotaryEmbedding from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.transformer_block import Qwen3VLTransformerBlock from megatron.bridge.models.transformer_config import TransformerConfig @@ -59,6 +57,7 @@ def __init__( 
seq_len_interpolation_factor: Optional[float] = None, mtp_block_spec: Optional[ModuleSpec] = None, vp_stage: Optional[int] = None, + pg_collection: ProcessGroupCollection = None, ) -> None: super().__init__( config=config, @@ -79,17 +78,18 @@ def __init__( seq_len_interpolation_factor=seq_len_interpolation_factor, mtp_block_spec=mtp_block_spec, vp_stage=vp_stage, + pg_collection=pg_collection, ) - is_moe = ( - hasattr(config, "num_moe_experts") and config.num_moe_experts is not None and config.num_moe_experts > 0 + # rebuild rope + self.rotary_pos_emb = Qwen3VLMultimodalRotaryEmbedding( + kv_channels=self.config.kv_channels, + rotary_percent=rotary_percent, + rotary_interleaved=self.config.rotary_interleaved, + seq_len_interpolation_factor=seq_len_interpolation_factor, + rotary_base=rotary_base, + cp_group=pg_collection.cp, ) - - if is_moe: - self.rotary_pos_emb = Qwen3VLMoETextRotaryEmbedding(config.hf_text_config) - else: - self.rotary_pos_emb = Qwen3VLTextRotaryEmbedding(config.hf_text_config) - self.mrope_section = self.config.mrope_section assert self.mrope_section is not None, ( "mrope require mrope_section setting, but we got None from TransformerConfig" @@ -102,6 +102,7 @@ def __init__( pre_process=self.pre_process, post_process=self.post_process, vp_stage=vp_stage, + pg_collection=pg_collection, ) def forward( diff --git a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/transformer_block.py b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/transformer_block.py index f0aa362aff..fa2838a530 100644 --- a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/transformer_block.py +++ b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/transformer_block.py @@ -22,14 +22,19 @@ from typing import Optional, Union import torch -from megatron.core import parallel_state, tensor_parallel +from torch import nn +from megatron.core import tensor_parallel from megatron.core.enums import Fp8Recipe from megatron.core.fp8_utils import get_fp8_context from megatron.core.inference.contexts import BaseInferenceContext from megatron.core.packed_seq_params import PackedSeqParams -from megatron.core.transformer.transformer_block import TransformerBlock +from megatron.core.transformer.transformer_block import TransformerBlock, TransformerBlockSubmodules +from megatron.core.transformer.spec_utils import ModuleSpec from megatron.core.utils import WrappedTensor, deprecate_inference_params, make_viewless_tensor from torch import Tensor +from megatron.core.dist_checkpointing.mapping import ShardedStateDict +from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding +from megatron.core.transformer.transformer_layer import get_transformer_layer_offset try: @@ -43,6 +48,418 @@ if HAVE_TE: from megatron.core.extensions.transformer_engine import te_checkpoint +from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.transformer_config import Qwen3VLTransformerConfig +from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.utils import Qwen3VLVisionPatchMerger +from megatron.core.process_groups_config import ProcessGroupCollection + + +class Qwen3VLVisionTransformerBlock(TransformerBlock): + def __init__( + self, + config: Qwen3VLTransformerConfig, + spec: Union[TransformerBlockSubmodules, ModuleSpec], + post_layer_norm: bool = True, + pre_process: bool = True, + post_process: bool = True, + # model_comm_pgs: ModelCommProcessGroups = None, + vp_stage: Optional[int] = None, + patch_merger_spec: ModuleSpec = None, + pg_collection: ProcessGroupCollection = None, + ): + assert post_process 
and pre_process, "pipeline parallelism is not supported: deepstack_merger_list requires both pre_process and post_process"
+        super().__init__(
+            config=config,
+            spec=spec,
+            post_layer_norm=post_layer_norm,
+            pre_process=pre_process,
+            post_process=post_process,
+            # model_comm_pgs=model_comm_pgs,
+            vp_stage=vp_stage,
+            pg_collection=pg_collection,
+        )
+        self.pg_collection = pg_collection
+        self.cp_group = pg_collection.cp
+        self.tp_group = pg_collection.tp
+        self.pp_group = pg_collection.pp
+
+        self.deepstack_visual_indexes = config.deepstack_visual_indexes
+        self.deepstack_merger_list = nn.ModuleList(
+            [
+                Qwen3VLVisionPatchMerger(
+                    config,
+                    patch_merger_spec,
+                    use_postshuffle_norm=True,
+                )
+                for _ in range(len(config.deepstack_visual_indexes))
+            ]
+        )
+
+    def _checkpointed_forward(
+        self,
+        hidden_states: Tensor,
+        attention_mask: Tensor,
+        context: Tensor,
+        context_mask: Tensor,
+        rotary_pos_emb: Tensor,
+        attention_bias: Tensor,
+        packed_seq_params: PackedSeqParams,
+        use_inner_fp8_context: bool,
+    ):
+        """Forward method with activation checkpointing."""
+
+        def custom(start: int, end: int):
+            def custom_forward(hidden_states, attention_mask, context, context_mask, rotary_pos_emb):
+                deepstack_feature_lists = []
+                for index in range(start, end):
+                    layer = self._get_layer(index)
+                    inner_fp8_context = (
+                        get_fp8_context(self.config, layer.layer_number - 1)
+                        if use_inner_fp8_context
+                        else nullcontext()
+                    )
+                    # Check if layer will use TE CUDA graph replay - if so, don't pass
+                    # packed_seq_params since CUDA graph only accepts tensor inputs.
+                    # Use layer.config (not self.config) because the layer's config is what
+                    # determines if _should_call_te_cudagraph returns True.
+                    layer_uses_te_cudagraph = (
+                        hasattr(layer, 'cuda_graphs')
+                        and layer.cuda_graphs
+                        and layer.training
+                        and hasattr(layer, 'config')
+                        and getattr(layer.config, 'cuda_graph_impl', 'none') == "transformer_engine"
+                    )
+                    layer_packed_seq_params = None if layer_uses_te_cudagraph else packed_seq_params
+
+                    with inner_fp8_context:
+                        hidden_states, context = layer(
+                            hidden_states=hidden_states,
+                            attention_mask=attention_mask,
+                            context=context,
+                            context_mask=context_mask,
+                            rotary_pos_emb=rotary_pos_emb,
+                            attention_bias=attention_bias,
+                            inference_context=None,
+                            packed_seq_params=layer_packed_seq_params,
+                        )
+
+                    l_no = layer.layer_number - 1
+                    if l_no in self.deepstack_visual_indexes:
+                        deepstack_idx = self.deepstack_visual_indexes.index(l_no)
+                        deepstack_feature = self.deepstack_merger_list[deepstack_idx](hidden_states)
+                        deepstack_feature_lists.append(deepstack_feature)
+                return hidden_states, deepstack_feature_lists, context
+
+            return custom_forward
+
+        def checkpoint_handler(forward_func):
+            """Determines whether to use `te_checkpoint` or `tensor_parallel.checkpoint`."""
+            if self.config.fp8:
+                return te_checkpoint(
+                    forward_func,
+                    self.config.distribute_saved_activations,
+                    tensor_parallel.random.get_cuda_rng_tracker,
+                    self.pg_collection.tp,
+                    hidden_states,
+                    attention_mask,
+                    context,
+                    context_mask,
+                    rotary_pos_emb,
+                )
+            else:
+                return tensor_parallel.checkpoint(
+                    forward_func,
+                    self.config.distribute_saved_activations,
+                    hidden_states,
+                    attention_mask,
+                    context,
+                    context_mask,
+                    rotary_pos_emb,
+                )
+
+        deepstack_feature_lists = []
+        if self.config.recompute_method == "uniform":
+            # Uniformly divide the total number of Transformer layers and checkpoint
+            # the input activation of each divided chunk.
+            # Storing fewer, coarser checkpoints this way further reduces memory usage.
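+            # For example (illustrative numbers only): with 24 layers on this rank
+            # and recompute_num_layers=4, the loop below re-runs the block as six
+            # checkpointed chunks of four layers each, storing only each chunk's
+            # input activation while accumulating deepstack features per chunk.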
+            layer_idx = 0
+
+            while layer_idx < self.num_layers_per_pipeline_rank:
+                hidden_states, layer_deepstack_feature_lists, context = checkpoint_handler(
+                    custom(layer_idx, layer_idx + self.config.recompute_num_layers)
+                )
+
+                layer_idx += self.config.recompute_num_layers
+                deepstack_feature_lists.extend(layer_deepstack_feature_lists)
+
+        elif self.config.recompute_method == "block":
+            # Checkpoint the input activation of only a set number of individual
+            # Transformer layers and skip recomputation for the rest.
+            # This makes fuller use of device memory while avoiding redundant recomputation.
+            recompute_skip_num_layers = 0
+            for layer_idx in range(self.num_layers_per_pipeline_rank):
+                # Skip recomputation when input grad computation is not needed.
+                # Need to have at least one input tensor with gradient computation
+                # for the re-entrant autograd engine.
+                if self.config.fp8 and not hidden_states.requires_grad:
+                    recompute_skip_num_layers += 1
+                if (
+                    layer_idx >= recompute_skip_num_layers
+                    and layer_idx < self.config.recompute_num_layers + recompute_skip_num_layers
+                ):
+                    hidden_states, layer_deepstack_feature_lists, context = checkpoint_handler(
+                        custom(layer_idx, layer_idx + 1)
+                    )
+                else:
+                    hidden_states, layer_deepstack_feature_lists, context = custom(layer_idx, layer_idx + 1)(
+                        hidden_states,
+                        attention_mask,
+                        context,
+                        context_mask,
+                        rotary_pos_emb,
+                    )
+                deepstack_feature_lists.extend(layer_deepstack_feature_lists)
+        else:
+            raise ValueError("Invalid activation recompute method.")
+
+        return hidden_states, deepstack_feature_lists
+
+    def forward(
+        self,
+        hidden_states: Union[Tensor, WrappedTensor],
+        attention_mask: Optional[Tensor],
+        context: Optional[Tensor] = None,
+        context_mask: Optional[Tensor] = None,
+        rotary_pos_emb: Optional[Tensor] = None,
+        rotary_pos_cos: Optional[Tensor] = None,
+        rotary_pos_sin: Optional[Tensor] = None,
+        attention_bias: Optional[Tensor] = None,
+        inference_context: Optional[BaseInferenceContext] = None,
+        packed_seq_params: Optional[PackedSeqParams] = None,
+        sequence_len_offset: Optional[Tensor] = None,
+        *,
+        inference_params: Optional[BaseInferenceContext] = None,
+    ):
+        """
+        Perform the forward pass through the transformer block.
+
+        This method handles the core computation of the transformer, including
+        self-attention, optional cross-attention, and feed-forward operations.
+
+        Args:
+            hidden_states (Union[Tensor, WrappedTensor]): Input tensor of shape [s, b, h]
+                where s is the sequence length, b is the batch size, and h is the hidden size.
+                Can be passed as a WrappedTensor during inference to avoid an obsolete
+                reference in the calling function.
+            attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking
+                self-attention.
+            context (Tensor, optional): Context tensor for cross-attention.
+            context_mask (Tensor, optional): Mask for the cross-attention context.
+            rotary_pos_emb (Tensor, optional): Rotary positional embeddings.
+            attention_bias (Tensor): Bias tensor for Q * K.T, in a shape broadcastable
+                to [b, num_head, sq, skv], e.g. [1, 1, sq, skv].
+                Used as an alternative to the attention mask for TE cuDNN attention.
+            inference_context (BaseInferenceContext, optional): Parameters for inference-time
+                optimizations.
+            packed_seq_params (PackedSeqParams, optional): Parameters for packed sequence
+                processing.
+
+        Returns:
+            Union[Tensor, Tuple[Tensor, Tensor]]: The output hidden states tensor of shape
+                [s, b, h], and optionally the updated context tensor if cross-attention is used.
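+
+            For this vision block the return value is always the tuple
+            (hidden_states, deepstack_feature_lists), where deepstack_feature_lists
+            holds the merged features tapped at config.deepstack_visual_indexes.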
+ """ + + inference_context = deprecate_inference_params(inference_context, inference_params) + + # Delete the obsolete reference to the initial input tensor if necessary + if isinstance(hidden_states, WrappedTensor): + hidden_states = hidden_states.unwrap() + + if not self.pre_process: + # See set_input_tensor() + hidden_states = self.input_tensor + + # Viewless tensor. + # - We only need to create a viewless tensor in the case of micro batch + # size (mbs) == 1, since in this case, 'hidden_states.transpose()' + # above creates a view tensor, and '.contiguous()' is a pass-through. + # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating + # the need to make it viewless. + # + # However, we don't explicitly check mbs == 1 here because + # make_viewless_tensor() has negligible overhead when its input + # is already viewless. + # + # - For the 'else' case above, calling make_viewless_tensor() here is + # likely redundant, since p2p_communication.py (likely originator) + # already creates viewless tensors. That said, make_viewless_tensor() + # is called here to be future-proof and corner-case-proof. + hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) + + if self.config.sequence_parallel: + rng_context = tensor_parallel.get_cuda_rng_tracker().fork() + else: + rng_context = nullcontext() + + # If fp8_recipe is delayed, wrap the entire pass with get_fp8_context(), + # otherwise do nothing extra at the outer level + # if we are using other fp8 recipes, then the context manager enter&exit are free + # we can wrap fp8_context within the for loop over layers, so that we can fine-grained + # control which layer will be fp8 or bf16 + use_outer_fp8_context = self.config.fp8 and self.config.fp8_recipe == Fp8Recipe.delayed + use_inner_fp8_context = self.config.fp8 and self.config.fp8_recipe != Fp8Recipe.delayed + outer_fp8_context = get_fp8_context(self.config) if use_outer_fp8_context else nullcontext() + + with rng_context, outer_fp8_context: + # Forward pass. + if self.config.recompute_granularity == "full" and self.training: + hidden_states, deepstack_feature_lists = self._checkpointed_forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + use_inner_fp8_context=use_inner_fp8_context, + ) + else: + deepstack_feature_lists = [] + for l_no, layer in enumerate(self.layers): + inner_fp8_context = ( + get_fp8_context(self.config, layer.layer_number - 1) + if use_inner_fp8_context + else nullcontext() + ) + assert l_no == layer.layer_number - 1 + with self.offload_context, inner_fp8_context: + # Check if layer will use TE CUDA graph replay - if so, don't pass + # packed_seq_params since CUDA graph only accepts tensor inputs. + # Use layer.config (not self.config) because the layer's config is what + # determines if _should_call_te_cudagraph returns True. 
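+                        # PackedSeqParams is a non-tensor dataclass, so it cannot be
+                        # baked into a captured graph; the vision model compensates by
+                        # padding to a fixed length and using a dense attention mask
+                        # instead (see Qwen3VLVisionModel.forward).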
+ layer_uses_te_cudagraph = ( + hasattr(layer, 'cuda_graphs') + and layer.cuda_graphs + and layer.training + and hasattr(layer, 'config') + and getattr(layer.config, 'cuda_graph_impl', 'none') == "transformer_engine" + ) + layer_packed_seq_params = None if layer_uses_te_cudagraph else packed_seq_params + + hidden_states, context = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + attention_bias=attention_bias, + inference_context=inference_context, + packed_seq_params=layer_packed_seq_params, + sequence_len_offset=sequence_len_offset, + ) + + if l_no in self.deepstack_visual_indexes: + deepstack_idx = self.deepstack_visual_indexes.index(l_no) + deepstack_feature = self.deepstack_merger_list[deepstack_idx](hidden_states) + deepstack_feature_lists.append(deepstack_feature) + + if ( + torch.is_grad_enabled() + and self.config.cpu_offloading + and self.group_prefetch_offload_commit_async is not None + ): + hidden_states = self.group_prefetch_offload_commit_async(hidden_states) + + # Final layer norm. + if self.final_layernorm is not None: + hidden_states = self.final_layernorm(hidden_states) + # TENorm produces a "viewed" tensor. This will result in schedule.py's + # deallocate_output_tensor() throwing an error, so a viewless tensor is + # created to prevent this. + hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) + + # If this TransformerBlock is empty, input and output hidden states will be the same node + # on the computational graph and will lead to unexpected errors in pipeline schedules. + if not self.pre_process and len(self.layers) == 0 and not self.final_layernorm: + hidden_states = hidden_states.clone() + + return hidden_states, deepstack_feature_lists + + def sharded_state_dict( + self, prefix: str = "", sharded_offsets: tuple = (), metadata: dict = None + ) -> ShardedStateDict: + """ + Generate a sharded state dictionary for the transformer block. + + Args: + prefix (str, optional): Prefix to be added to all keys in the state dict. + Defaults to an empty string. + sharded_offsets (tuple, optional): Tuple of sharding offsets. + metadata (dict, optional): Additional metadata for sharding. + Can specify if layers are non-homogeneous. Defaults to None. + + Returns: + ShardedStateDict: A dictionary containing the sharded state of the model. + """ + assert not sharded_offsets, "Unexpected sharded offsets" + non_homogeneous_layers = metadata is not None and metadata.get("non_homogeneous_layers", False) + if self.config.hetereogenous_dist_checkpoint: + non_homogeneous_layers = True + + if isinstance(self.config.moe_layer_freq, int): + if self.config.moe_layer_freq > 1: + non_homogeneous_layers = True + elif isinstance(self.config.moe_layer_freq, list): + non_homogeneous_layers = True + + if self.config.heterogeneous_block_specs: + non_homogeneous_layers = True + + sharded_state_dict = {} + + layer_prefix = f"{prefix}layers." + num_layers = self.config.num_layers + for layer in self.layers: + offset = get_transformer_layer_offset(self.config, self.vp_stage, pp_rank=self.pp_group.rank()) + + global_layer_offset = layer.layer_number - 1 # self.layer_number starts at 1 + state_dict_prefix = f"{layer_prefix}{global_layer_offset - offset}." 
# module list index in TransformerBlock # pylint: disable=line-too-long + if non_homogeneous_layers: + sharded_prefix = f"{layer_prefix}{global_layer_offset}." + sharded_pp_offset = [] + else: + sharded_prefix = layer_prefix + sharded_pp_offset = [(0, global_layer_offset, num_layers)] # PP sharding offset for ShardedTensors + layer_sharded_state_dict = layer.sharded_state_dict(state_dict_prefix, sharded_pp_offset, metadata) + replace_prefix_for_sharding(layer_sharded_state_dict, state_dict_prefix, sharded_prefix) + + sharded_state_dict.update(layer_sharded_state_dict) + + len_deepstack = len(self.deepstack_merger_list) + deepstack_prefix = f"{prefix}deepstack_merger_list." + for global_layer_offset, layer in enumerate(self.deepstack_merger_list): + state_dict_prefix = f"{deepstack_prefix}{global_layer_offset}." # module list index in TransformerBlock # pylint: disable=line-too-long + if non_homogeneous_layers: + sharded_prefix = f"{deepstack_prefix}{global_layer_offset}." + sharded_pp_offset = [] + else: + sharded_prefix = deepstack_prefix + sharded_pp_offset = [(0, global_layer_offset, len_deepstack)] # PP sharding offset for ShardedTensors + layer_sharded_state_dict = layer.sharded_state_dict(state_dict_prefix, sharded_pp_offset, metadata) + replace_prefix_for_sharding(layer_sharded_state_dict, state_dict_prefix, sharded_prefix) + + sharded_state_dict.update(layer_sharded_state_dict) + + # Add modules other than self.layers + for name, module in self.named_children(): + if (not module is self.layers) and (not module is self.deepstack_merger_list): + sharded_state_dict.update( + sharded_state_dict_default(module, f"{prefix}{name}.", sharded_offsets, metadata) + ) + + return sharded_state_dict + class Qwen3VLTransformerBlock(TransformerBlock): """Transformer class.""" @@ -71,8 +488,9 @@ def custom_forward( context_mask, rotary_pos_emb, visual_pos_masks, - deepstack_visual_embeds, + *deepstack_visual_embeds_args, ): + deepstack_visual_embeds = list(deepstack_visual_embeds_args) if deepstack_visual_embeds_args else None for index in range(start, end): layer = self._get_layer(index) inner_fp8_context = ( @@ -104,6 +522,8 @@ def custom_forward( return custom_forward + deepstack_visual_embeds_tuple = tuple(deepstack_visual_embeds) if deepstack_visual_embeds else () + def checkpoint_handler(forward_func): """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" if self.config.fp8: @@ -111,14 +531,14 @@ def checkpoint_handler(forward_func): forward_func, self.config.distribute_saved_activations, tensor_parallel.random.get_cuda_rng_tracker, - parallel_state.get_tensor_model_parallel_group(), + self.pg_collection.tp, hidden_states, attention_mask, context, context_mask, rotary_pos_emb, visual_pos_masks, - deepstack_visual_embeds, + *deepstack_visual_embeds_tuple, ) else: return tensor_parallel.checkpoint( @@ -130,7 +550,7 @@ def checkpoint_handler(forward_func): context_mask, rotary_pos_emb, visual_pos_masks, - deepstack_visual_embeds, + *deepstack_visual_embeds_tuple, ) if self.config.recompute_method == "uniform": @@ -169,7 +589,7 @@ def checkpoint_handler(forward_func): context_mask, rotary_pos_emb, visual_pos_masks, - deepstack_visual_embeds, + *deepstack_visual_embeds_tuple, ) else: raise ValueError("Invalid activation recompute method.") diff --git a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/transformer_config.py b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/transformer_config.py index 8d9284c2ef..b8c8a1903e 100644 --- 
a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/transformer_config.py +++ b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/transformer_config.py @@ -12,10 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. - +from copy import deepcopy from dataclasses import dataclass, field +from functools import partial from typing import List, Optional +import torch +import torch.nn.functional as F from megatron.core.transformer.transformer_config import TransformerConfig from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLTextConfig @@ -49,3 +52,83 @@ class Qwen3VLTransformerConfig(TransformerConfig): video_token_id: int = 151656 vision_start_token_id: int = 151652 hf_text_config: Optional[Qwen3VLTextConfig] = None + + # Vision encoder CUDA graph settings + # Maximum sequence length for vision encoder CUDA graphs. + # Set this to accommodate the largest expected vision input. + # If None, will be calculated from num_position_embeddings / spatial_merge_size^2 + max_vision_cuda_graph_seq_length: Optional[int] = None + + +def get_vision_model_config(config: Qwen3VLTransformerConfig, hf_config): + + # Set vision encoder parameters from HF config + # CRITICAL: num_layers must be set to vision depth, not inherited from language model + config.num_layers = hf_config.depth + config.hidden_size = hf_config.hidden_size + config.num_attention_heads = hf_config.num_heads + config.ffn_hidden_size = hf_config.intermediate_size + config.add_bias_linear = True + config.add_qkv_bias = True + + config.num_moe_experts = None + config.expert_model_parallel_size = 1 + config.moe_ffn_hidden_size = None + + config.hidden_dropout = 0.0 + config.attention_dropout = 0.0 + config.layernorm_epsilon = 1e-6 + config.apply_rotary_pos_emb_in_fp32 = True + + config.patch_size = hf_config.patch_size + config.temporal_patch_size = hf_config.temporal_patch_size + config.in_channels = hf_config.in_channels + config.spatial_merge_size = hf_config.spatial_merge_size + config.num_position_embeddings = hf_config.num_position_embeddings + config.out_hidden_size = hf_config.out_hidden_size + config.deepstack_visual_indexes = deepcopy(hf_config.deepstack_visual_indexes) + + config.gated_linear_unit = False # no gated + config.activation_func = partial(F.gelu, approximate="tanh") # hidden_act + config.kv_channels = config.hidden_size // config.num_attention_heads + config.num_query_groups = config.num_attention_heads # no GQA + config.layernorm_zero_centered_gamma = False # False + config.apply_query_key_layer_scaling = False # factor=math.sqrt(head_dim) + config.bias_activation_fusion = False # no swiglu, set false + config.bias_dropout_fusion = False # no dropout, set false + config.attention_softmax_in_fp32 = True # use True + config.normalization = "LayerNorm" + + config.tp_comm_overlap = False + config.sequence_parallel = False + config.context_parallel_size = 1 + config.pipeline_model_parallel_size = 1 + config.num_layers_in_first_pipeline_stage = None + config.num_layers_in_last_pipeline_stage = None + config.virtual_pipeline_model_parallel_size = None + config.pipeline_model_parallel_layout = None + config.account_for_embedding_in_pipeline_split = None + config.account_for_loss_in_pipeline_split = None + # encoder does not support apply_rope_fusion currently. 
+ config.apply_rope_fusion = False + + # Vision encoder CUDA graph settings + # Check if the input config has vision-specific CUDA graph settings (from provider) + # If so, use them; otherwise default to "none" for backward compatibility + if hasattr(config, 'vision_cuda_graph_impl') and config.vision_cuda_graph_impl != "none": + config.cuda_graph_impl = config.vision_cuda_graph_impl + if hasattr(config, 'vision_cuda_graph_scope') and config.vision_cuda_graph_scope: + # Convert string scope list to CudaGraphScope enums if needed + from megatron.core.transformer.cuda_graphs import CudaGraphScope + scope_list = config.vision_cuda_graph_scope + if scope_list and isinstance(scope_list[0], str): + config.cuda_graph_scope = [CudaGraphScope[scope] for scope in scope_list] + else: + config.cuda_graph_scope = scope_list + else: + config.cuda_graph_scope = [] + else: + config.cuda_graph_impl = "none" + config.cuda_graph_scope = [] + + return config diff --git a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/utils.py b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/utils.py index b172a2f07d..657da09289 100644 --- a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/utils.py +++ b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/utils.py @@ -13,12 +13,169 @@ # limitations under the License. -from typing import Optional +from dataclasses import dataclass +from typing import Optional, Union import torch +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from torch import nn +from megatron.core import mpu +from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.transformer_config import Qwen3VLTransformerConfig from megatron.core.packed_seq_params import PackedSeqParams +# copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py +class Qwen3VLVisionPatchEmbed(nn.Module): + def __init__( + self, + config: Qwen3VLTransformerConfig, + ) -> None: + super().__init__() + self.patch_size = config.patch_size + self.temporal_patch_size = config.temporal_patch_size + self.in_channels = config.in_channels + self.embed_dim = config.hidden_size + + kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size] + self.proj = nn.Conv3d( + self.in_channels, + self.embed_dim, + kernel_size=kernel_size, + stride=kernel_size, + bias=True, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.view( + -1, + self.in_channels, + self.temporal_patch_size, + self.patch_size, + self.patch_size, + ) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) + return hidden_states + + +# copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py +class Qwen3VLVisionRotaryEmbedding(nn.Module): + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + self.dim = dim + self.theta = theta + + def forward(self, seqlen: int) -> torch.Tensor: + if not hasattr(self, "inv_freq"): + inv_freq = 1.0 / ( + self.theta + ** ( + torch.arange( + 0, + self.dim, + 2, + dtype=torch.float, + device=torch.cuda.current_device(), + ) + / self.dim + ) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + 
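+# A rough usage sketch (hypothetical sizes, not part of the module API):
+#   rope = Qwen3VLVisionRotaryEmbedding(dim=64)  # head_dim // 2 for 128-dim heads
+#   freqs = rope(40)                             # -> [40, 32] table of rotary angles
+# Qwen3VLVisionModel.rot_pos_emb() indexes this table with per-token (row, col)
+# position ids to build a 2-D spatial rotary embedding.
+
+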
+@dataclass +class PatchMergerSubmodules: + patch_norm: Union[ModuleSpec, type] = None + linear_fc1: Union[ModuleSpec, type] = None + linear_fc2: Union[ModuleSpec, type] = None + + +class Qwen3VLVisionPatchMerger(MegatronModule): + def __init__( + self, + config: Qwen3VLTransformerConfig, + submodules: PatchMergerSubmodules, + use_postshuffle_norm=False, + tp_group: Optional[torch.distributed.ProcessGroup] = None, + ): + super().__init__(config=config) + + self.hidden_size = config.hidden_size * (config.spatial_merge_size**2) + self.use_postshuffle_norm = use_postshuffle_norm + self.input_size = config.hidden_size + if self.use_postshuffle_norm: + self.input_size = self.hidden_size + self.tp_group = tp_group + + self.patch_norm = build_module( + submodules.patch_norm, + config=self.config, + hidden_size=self.input_size, + eps=self.config.layernorm_epsilon, + ) + + self.linear_fc1 = build_module( + submodules.linear_fc1, + self.hidden_size, + self.hidden_size, + config=self.config, + init_method=self.config.init_method, + gather_output=False, + bias=self.config.add_bias_linear, + skip_bias_add=False, + is_expert=False, + tp_comm_buffer_name="patch_fc1", + tp_group=tp_group, + ) + + self.activation_func = self.config.activation_func + + self.linear_fc2 = build_module( + submodules.linear_fc2, + self.hidden_size, + self.config.out_hidden_size, + config=self.config, + init_method=self.config.output_layer_init_method, + bias=self.config.add_bias_linear, + input_is_parallel=True, + skip_bias_add=False, + is_expert=False, + tp_comm_buffer_name="patch_fc1", + tp_group=tp_group, + ) + + def forward(self, hidden_states): + if self.use_postshuffle_norm: + hidden_states = hidden_states.view(-1, self.hidden_size) + hidden_states = self.patch_norm(hidden_states) + hidden_states = hidden_states.view(-1, self.hidden_size) + + hidden_states, _ = self.linear_fc1(hidden_states) + hidden_states = self.activation_func(hidden_states) + output, _ = self.linear_fc2(hidden_states) + + return output + + +def split_part_by_cp_tp(cp_size, cp_rank, tp_size, tp_rank, split_size): + part_list = list(range(split_size)) + + cp_rank2 = 2 * cp_size - cp_rank - 1 + cp_part_list = ( + part_list[cp_rank * tp_size : (cp_rank + 1) * tp_size] + + part_list[cp_rank2 * tp_size : (cp_rank2 + 1) * tp_size] + ) + + assert len(cp_part_list) % tp_size == 0 + echo_tp_len = len(cp_part_list) // tp_size + cp_tp_part_list = cp_part_list[tp_rank * echo_tp_len : (tp_rank + 1) * echo_tp_len] + return cp_tp_part_list + + def split_deepstack_embs( visual_pos_masks: torch.Tensor, deepstack_visual_embeds: list[torch.Tensor], @@ -27,180 +184,555 @@ def split_deepstack_embs( cp_size: int = 1, cp_rank: int = 0, ): - """Split deepstack visual embeddings for tensor and context parallelism. 
-
-    Args:
-        visual_pos_masks: Visual position masks tensor
-        deepstack_visual_embeds: List of deepstack visual embeddings
-        tp_size: Tensor parallel size (default: 1)
-        tp_rank: Tensor parallel rank (default: 0)
-        cp_size: Context parallel size (default: 1)
-        cp_rank: Context parallel rank (default: 0)
-
-    Returns:
-        Split visual embeddings based on parallelism configuration
-    """
-    # first split by cp (zigzag)
-    assert cp_size == 1 and cp_rank == 0, "no support cp now"
-
-    # split by tp
-    if tp_size == 1 or visual_pos_masks is None:
+    split_size = tp_size
+    if cp_size > 1:
+        split_size *= cp_size * 2
+    if split_size == 1 or visual_pos_masks is None:
        return visual_pos_masks, deepstack_visual_embeds
    assert visual_pos_masks.dim() == 2
+    assert visual_pos_masks.shape[-1] % split_size == 0
    batch_size = visual_pos_masks.size(0)
-    visual_pos_masks_list = visual_pos_masks.chunk(tp_size, dim=-1)
+
+    # First split by cp (zigzag), then split by tp.
+    # For example, with cp=2/tp=4,
+    # visual_pos_masks is split into 16 parts:
+    # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
+    # the first split by cp (zigzag) gives:
+    # cp_rank0: [0, 1, 2, 3, 12, 13, 14, 15]
+    # cp_rank1: [4, 5, 6, 7, 8, 9, 10, 11]
+    # and the subsequent split by tp gives:
+    # cp_rank0/tp_rank0 = [0, 1]
+    # cp_rank0/tp_rank1 = [2, 3]
+    # ...
+    # cp_rank1/tp_rank2 = [8, 9]
+    # cp_rank1/tp_rank3 = [10, 11]
+    cp_tp_part_list = split_part_by_cp_tp(cp_size, cp_rank, tp_size, tp_rank, split_size)
+    visual_pos_masks_list = visual_pos_masks.chunk(split_size, dim=-1)
    embed_lens = [ele.sum(-1) for ele in visual_pos_masks_list]
    embed_lens = torch.stack(embed_lens, dim=-1)
    embed_cu_lens = embed_lens.view(-1).cumsum(dim=-1).tolist()
+    assert len(embed_cu_lens) == split_size * batch_size
    embed_cu_lens = [0] + embed_cu_lens
-    tp_slices = []
+    cp_tp_slices = []
    for i in range(batch_size):
-        idx = i * tp_size + tp_rank
-        tp_slices.append(slice(embed_cu_lens[idx], embed_cu_lens[idx + 1]))
+        for idx in cp_tp_part_list:
+            idx += i * split_size
+            cp_tp_slices.append(slice(embed_cu_lens[idx], embed_cu_lens[idx + 1]))
    deepstack_visual_embeds_ret = []
    for deepstack_visual_embed in deepstack_visual_embeds:
        tmp_slice_tensor = []
-        for tp_slice in tp_slices:
+        for tp_slice in cp_tp_slices:
            tmp_slice_tensor.append(deepstack_visual_embed[tp_slice])
        deepstack_visual_embeds_ret.append(torch.cat(tmp_slice_tensor, dim=0))
-    return visual_pos_masks_list[tp_rank], deepstack_visual_embeds_ret
+    visual_pos_masks_ret = torch.cat([visual_pos_masks_list[i] for i in cp_tp_part_list], dim=-1)
+
+    return visual_pos_masks_ret, deepstack_visual_embeds_ret


-def get_rope_index(
-    spatial_merge_size: int,
+def find_vision_id_index(
+    input_ids: torch.Tensor,
    image_token_id: int,
    video_token_id: int,
-    vision_start_token_id: int,
-    input_ids: Optional[torch.LongTensor] = None,
-    image_grid_thw: Optional[torch.LongTensor] = None,
-    video_grid_thw: Optional[torch.LongTensor] = None,
-    attention_mask: Optional[torch.Tensor] = None,
-    packed_seq_params: Optional[PackedSeqParams] = None,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    """Different from the original implementation, Qwen3VL use timestamps rather than absolute time position ids."""
-
-    # Since we use timestamps to seperate videos, like , the video_grid_thw should also be split
-    if video_grid_thw is not None:
-        video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0)
-        video_grid_thw[:, 0] = 1
-
-    if packed_seq_params is not None and attention_mask is None and input_ids is not None:
-        # Build an attention mask from packed
sequence metadata when one is not provided.
-        # cu_seqlens_q entries are cumulative lengths; their diffs give per-sample lengths.
-        cu_seqlens = packed_seq_params.cu_seqlens_q
-        if cu_seqlens is not None and cu_seqlens.numel() >= 2:
-            seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
-            attention_mask = torch.zeros_like(input_ids, dtype=input_ids.dtype)
-            max_len = attention_mask.shape[1]
-            for i, seq_len in enumerate(seq_lens.tolist()):
-                valid = min(int(seq_len), max_len)
-                attention_mask[i, :valid] = 1
+):
+    assert input_ids.dim() == 1, "input_ids should be flattened"
+    if input_ids.numel() == 0:
+        return []
+
+    device = input_ids.device
+    dtype = input_ids.dtype
+    assert dtype in [torch.int, torch.int64]
+
+    # Keep image/video token ids where they occur; set every other position to -1.
+    code = torch.where(
+        (input_ids == image_token_id) | (input_ids == video_token_id),
+        input_ids,
+        torch.tensor(-1, device=device, dtype=dtype),
+    )
+
+    # Find the indices where the run value changes.
+    first = torch.tensor([True], device=device, dtype=torch.bool)
+    change = torch.cat([first, code[1:] != code[:-1]])
+    change_idx = torch.nonzero(change, as_tuple=False).flatten()
+
+    # Keep only the runs that start with an image/video token.
+    keep = code[change_idx] > 0
+    starts = change_idx[keep]
+
+    # The final run ends at input_ids.numel().
+    next_change = torch.cat(
+        [
+            change_idx[1:],
+            torch.tensor([input_ids.numel()], device=device, dtype=change_idx.dtype),
+        ]
+    )
+    ends = next_change[keep]
+
+    vals = code[starts]
+    starts_cpu = starts.tolist()
+    ends_cpu = ends.tolist()
+    vals_cpu = vals.tolist()
+    return [(int(s), int(e), int(v)) for s, e, v in zip(starts_cpu, ends_cpu, vals_cpu)]
+
+
+def reorganize_inputs(
+    input_ids: torch.Tensor,
+    pixel_values: torch.Tensor = None,
+    pixel_values_videos: torch.Tensor = None,
+    image_grid_thw: torch.Tensor = None,
+    video_grid_thw: torch.Tensor = None,
+    image_input_mask: torch.Tensor = None,
+    video_input_mask: torch.Tensor = None,
+    image_token_id: int = 151655,
+    video_token_id: int = 151656,
+    square_merge_size: int = 4,
+):
+    if pixel_values is None:
+        if video_input_mask is None and pixel_values_videos is not None:
+            video_input_mask = (input_ids == video_token_id).contiguous()
+        return pixel_values_videos, video_grid_thw, video_input_mask
+
+    if pixel_values_videos is None:
+        if image_input_mask is None and pixel_values is not None:
+            image_input_mask = (input_ids == image_token_id).contiguous()
+        return pixel_values, image_grid_thw, image_input_mask
+
+    image_thw_cpu = image_grid_thw.tolist()
+    video_thw_cpu = video_grid_thw.tolist()
+    vision_indexs = find_vision_id_index(input_ids.view(-1), image_token_id, video_token_id)
+    len_split = sum([thw[0] for thw in image_thw_cpu])
+    len_split += sum([thw[0] for thw in video_thw_cpu])
+    assert len_split == len(vision_indexs)
+
+    vision_values = []
+    vision_grid_thw = []
+    idx = 0
+    video_idx = 0
+    image_idx = 0
+    video_seqlen = 0
+    image_seqlen = 0
+    while idx < len(vision_indexs):
+        start, end, token_id = vision_indexs[idx]
+        if token_id == image_token_id:
+            seqlen = 0
+            thw = image_thw_cpu[image_idx]
+            for i in range(thw[0]):
+                start, end, token_id = vision_indexs[idx + i]
+                assert token_id == image_token_id
+                seqlen += (end - start) * square_merge_size
+            assert seqlen == thw[0] * thw[1] * thw[2]
+            vision_values.append(pixel_values[image_seqlen : (image_seqlen + seqlen)])
+            vision_grid_thw.append(thw)
+
+            image_idx += 1
+            idx += thw[0]
+            image_seqlen += seqlen
+        elif token_id == video_token_id:
+            seqlen = 0
+            thw = video_thw_cpu[video_idx]
+
for i in range(thw[0]):
+                start, end, token_id = vision_indexs[idx + i]
+                assert token_id == video_token_id
+                seqlen += (end - start) * square_merge_size
+            assert seqlen == thw[0] * thw[1] * thw[2]
+            vision_values.append(pixel_values_videos[video_seqlen : (video_seqlen + seqlen)])
+            vision_grid_thw.append(thw)
+
+            video_idx += 1
+            idx += thw[0]
+            video_seqlen += seqlen
        else:
-            # Fallback to a dense mask if packed metadata is missing.
-            attention_mask = torch.ones_like(input_ids)
-
-    mrope_position_deltas = []
-    if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
-        total_input_ids = input_ids
-        if attention_mask is None:
-            attention_mask = torch.ones_like(total_input_ids)
-        position_ids = torch.ones(
-            3,
-            input_ids.shape[0],
-            input_ids.shape[1],
-            dtype=input_ids.dtype,
-            device=input_ids.device,
-        )
-        image_index, video_index = 0, 0
-        attention_mask = attention_mask.to(total_input_ids.device)
-        for i, input_ids in enumerate(total_input_ids):
-            input_ids = input_ids[attention_mask[i] == 1]
-            image_nums, video_nums = 0, 0
-            vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
-            vision_tokens = input_ids[vision_start_indices + 1]
-            image_nums = (vision_tokens == image_token_id).sum()
-            video_nums = (vision_tokens == video_token_id).sum()
-            input_tokens = input_ids.tolist()
-            llm_pos_ids_list: list = []
-            st = 0
-            remain_images, remain_videos = image_nums, video_nums
-            for _ in range(image_nums + video_nums):
-                if image_token_id in input_tokens and remain_images > 0:
-                    ed_image = input_tokens.index(image_token_id, st)
-                else:
-                    ed_image = len(input_tokens) + 1
-                if video_token_id in input_tokens and remain_videos > 0:
-                    ed_video = input_tokens.index(video_token_id, st)
-                else:
-                    ed_video = len(input_tokens) + 1
-                if ed_image < ed_video:
-                    t, h, w = (
-                        image_grid_thw[image_index][0],
-                        image_grid_thw[image_index][1],
-                        image_grid_thw[image_index][2],
-                    )
-                    image_index += 1
-                    remain_images -= 1
-                    ed = ed_image
-
-                else:
-                    t, h, w = (
-                        video_grid_thw[video_index][0],
-                        video_grid_thw[video_index][1],
-                        video_grid_thw[video_index][2],
-                    )
-                    video_index += 1
-                    remain_videos -= 1
-                    ed = ed_video
-                llm_grid_t, llm_grid_h, llm_grid_w = (
-                    t.item(),
-                    h.item() // spatial_merge_size,
-                    w.item() // spatial_merge_size,
+            assert False, f"should not have {token_id=}"
+
+    if video_input_mask is None:
+        video_input_mask = input_ids == video_token_id
+
+    if image_input_mask is None:
+        image_input_mask = input_ids == image_token_id
+
+    vision_values = torch.cat(vision_values)
+    vision_grid_thw = torch.tensor(vision_grid_thw, device=image_grid_thw.device, dtype=image_grid_thw.dtype)
+    vision_input_mask = video_input_mask | image_input_mask
+
+    return vision_values, vision_grid_thw, vision_input_mask
+
+
+# reference: megatron/training/utils.py get_batch_on_this_cp_rank
+def split_data_cp_rank(val: torch.Tensor, cp_size: int, seq_dim: int, cp_rank: int = None):
+    if val is None:
+        return val
+    assert cp_size > 1
+    assert 0 == val.shape[seq_dim] % (2 * cp_size), f"{val.shape=} {cp_size=}"
+    if cp_rank is None:
+        cp_rank = mpu.get_context_parallel_rank()
+
+    val = val.view(
+        *val.shape[0:seq_dim],
+        2 * cp_size,
+        val.shape[seq_dim] // (2 * cp_size),
+        *val.shape[(seq_dim + 1) :],
+    )
+
+    index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device=val.device)
+    val = val.index_select(seq_dim, index)
+    val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2) :])
+
+    return val
+
+
+def expand_thw(thw: torch.Tensor) ->
torch.Tensor: + assert thw.dim() == 2 + repeats = thw[:, 0].to(torch.long) + assert torch.all(repeats > 0), "thw[:,0] must be > 0" + + idx = torch.arange(thw.size(0), device=thw.device).repeat_interleave(repeats) + out = thw[idx].clone() + out[:, 0] = 1 + return out + + +def collapse_thw(expanded: torch.Tensor) -> torch.Tensor: + assert expanded.dim() == 2 + assert expanded.size(1) >= 2 + if expanded.shape[0] < 2: + return expanded + + # find the diff + other = expanded[:, 1:] + prev = torch.cat([other[:1], other[:-1]], dim=0) + change = (other != prev).any(dim=1) + # the index0 must be now row + change[0] = True + + # find the diff + starts = torch.nonzero(change, as_tuple=False).squeeze(1) + ends = torch.cat([starts[1:], torch.tensor([other.size(0)], device=other.device)]) - 1 + counts = ends - starts + 1 + + rows_other = other[starts] + result_first_col = counts.to(expanded.dtype).unsqueeze(1) + result = torch.cat([result_first_col, rows_other], dim=1) + return result + + +# also can use in qwen2vl/qwen2.5vl +def qwen2vl_pad_and_split( + cp_size: int, + hw_factor: int, + pixel_values: list[torch.Tensor], + image_grid_thws: list[torch.Tensor], +): + assert len(pixel_values) == len(image_grid_thws) + # split the pixel_values + split_pixel_values = [] + split_image_grid_thws = [] + for pixel_value, image_grid_thw in zip(pixel_values, image_grid_thws): + split_image_grid_thw = list(torch.split(image_grid_thw, 1, dim=0)) + split_image_grid_thws.extend(split_image_grid_thw) + slice_begin = 0 + for ele in split_image_grid_thw: + slice_end = slice_begin + ele.prod().item() + split_pixel_values.append(pixel_value[slice_begin:slice_end].clone()) + slice_begin = slice_end + + pixel_values = split_pixel_values + image_grid_thws = split_image_grid_thws + img_num = len(image_grid_thws) + + img_num_per_rank = img_num // cp_size + img_num_remain = img_num % cp_size + cp_img_num = [] + for i in range(cp_size): + cp_img_num.append(img_num_per_rank) + if i < img_num_remain: + cp_img_num[i] += 1 + + img_idx = 0 + new_pixel_values = [] + new_image_grid_thws = [] + images_padded = [] + for i in range(cp_size): + seq_len = 0 + img_begin_idx = img_idx + img_end_idx = img_begin_idx + cp_img_num[i] + img_idx += cp_img_num[i] + + for j in range(img_begin_idx, img_end_idx): + seq_len += pixel_values[j].size(0) + new_pixel_values.append(pixel_values[j]) + new_image_grid_thws.append(image_grid_thws[j]) + + image_padded = 0 != seq_len % hw_factor + if image_padded: + padded_seqlen = (seq_len + hw_factor - 1) // hw_factor * hw_factor - seq_len + assert padded_seqlen > 0 and padded_seqlen % 4 == 0 + new_pixel_values.append( + torch.zeros( + [padded_seqlen, pixel_values[0].size(-1)], + dtype=pixel_values[0].dtype, + device=pixel_values[0].device, ) - text_len = ed - st - - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) - - # t_index is always 0 because llm_grid_t is always 1 (we use timestamps to encode the temporal information for videos) - t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() - h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() - w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() - llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) - st = ed + llm_grid_t * llm_grid_h * llm_grid_w - - if st < len(input_tokens): - st_idx = 
llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
-                text_len = len(input_tokens) - st
-                llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
-
-            llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
-            position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
-            mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
-        mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
-        return position_ids, mrope_position_deltas
-    else:
-        if attention_mask is not None:
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
-            max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
-            mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
+            )
+            new_image_grid_thws.append(
+                torch.tensor(
+                    [[1, 2, padded_seqlen // 2]],
+                    dtype=image_grid_thws[0].dtype,
+                    device=image_grid_thws[0].device,
+                )
+            )
+            cp_img_num[i] += 1
+        images_padded.append(int(image_padded))
+
+    return new_pixel_values, new_image_grid_thws, cp_img_num, images_padded
+
+
+@torch.no_grad
+def qwen3vl_cp_split(
+    cp_size: int,
+    pixel_values: torch.Tensor,
+    image_grid_thw: torch.Tensor,
+):
+    assert cp_size > 1
+    if pixel_values is None:
+        assert image_grid_thw is None
+        return None, None, None, None
+
+    assert not pixel_values.requires_grad
+    assert not image_grid_thw.requires_grad
+    # Expand multi-frame video thw rows into one row per frame.
+    image_grid_thw = expand_thw(image_grid_thw)
+
+    hw_factor = 4
+    new_pixel_values, new_image_grid_thws, cp_img_num, images_padded = qwen2vl_pad_and_split(
+        cp_size,
+        hw_factor,
+        [pixel_values],
+        [image_grid_thw],
+    )
+    for image_padded in images_padded:
+        assert not image_padded, "qwen3vl ViT does not support sequence padding under CP yet; inputs should not require padding"
+
+    pixel_values = torch.cat(new_pixel_values, dim=0)
+    image_grid_thw = torch.cat(new_image_grid_thws, dim=0)
+    return pixel_values, image_grid_thw, cp_img_num, images_padded
+
+
+def get_vision_cp_data(
+    vision_data: torch.Tensor,
+    vision_grid_thw: torch.Tensor,
+    square_merge_size: int,
+    cp_img_num: list[int],
+    images_padded: list[bool],
+):
+    """Get vision data and grid_thw for context parallelism.
+
+    Returns:
+        vision_data (torch.Tensor): Vision data of shape [total_thw_size, n_features].
+        vision_grid_thw (torch.Tensor): Vision grid_thw of shape [total_thw_size, 3].
+        seqlens_list (list of torch.Tensor): List of seqlens of the vision data on each context
+            parallel rank, for the all-gather after the vision encoder.
+    """
+    # We reuse the LLM's context-parallel size and group for the vision model;
+    # only the number of images is divided across context-parallel ranks.
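+    # For example (hypothetical): 5 images with cp_size=2 gives cp_img_num=[3, 2];
+    # rank 0 encodes the first three images and rank 1 the remaining two, while
+    # seqlens_list records per-rank merged-token counts for the later all-gather.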
+ cp_size = mpu.get_context_parallel_world_size() + cp_rank = mpu.get_context_parallel_rank() + assert cp_size == len(cp_img_num) + + seqlens = torch.repeat_interleave(vision_grid_thw[:, 1] * vision_grid_thw[:, 2], vision_grid_thw[:, 0]) + vision_grid_thw_list = [] + vision_data_list = [] + seqlens_list = [] + img_idx = 0 + for i in range(cp_size): + start_idx = img_idx + end_idx = start_idx + cp_img_num[i] + img_idx += cp_img_num[i] + + vision_grid_thw_list.append(vision_grid_thw[start_idx:end_idx]) + if images_padded[i]: + seqlens_list.append(seqlens[start_idx : end_idx - 1]) else: - position_ids = ( - torch.arange(input_ids.shape[1], device=input_ids.device) - .view(1, 1, -1) - .expand(3, input_ids.shape[0], -1) + seqlens_list.append(seqlens[start_idx:end_idx]) + data_start_idx = seqlens[:start_idx].sum() + data_end_idx = seqlens[:end_idx].sum() + vision_data_list.append(vision_data[data_start_idx:data_end_idx]) + new_vision_grid_thw = vision_grid_thw_list[cp_rank] + new_vision_data = vision_data_list[cp_rank] + new_seqlens_list = [t // square_merge_size for t in seqlens_list] + return new_vision_data, new_vision_grid_thw, new_seqlens_list + + +class AllGatherVisionEmbeddings(torch.autograd.Function): + @staticmethod + def forward(ctx, input, seqlens_on_cp_ranks): + outputs = [] + for i in range(len(seqlens_on_cp_ranks)): + o = torch.zeros( + (seqlens_on_cp_ranks[i].sum(), *input.shape[1:]), + device=input.device, + dtype=input.dtype, + layout=input.layout, ) - mrope_position_deltas = torch.zeros( - [input_ids.shape[0], 1], - device=input_ids.device, - dtype=input_ids.dtype, + outputs.append(o) + torch.distributed.all_gather(outputs, input, group=mpu.get_context_parallel_group()) + cp_rank = mpu.get_context_parallel_rank() + ctx.cp_rank = cp_rank + ctx.save_for_backward(*seqlens_on_cp_ranks) + + output = torch.cat(outputs, dim=0) + return output + + @staticmethod + def backward(ctx, grad_output): + cp_rank = ctx.cp_rank + seqlens_on_cp_ranks = ctx.saved_tensors + start_idx = torch.cat(seqlens_on_cp_ranks[:cp_rank]).sum() if cp_rank != 0 else 0 + end_idx = start_idx + seqlens_on_cp_ranks[cp_rank].sum() + grad_output = grad_output[start_idx:end_idx] + return grad_output, None + + +def preprocess_packed_seqs( + input_ids: torch.Tensor, attention_mask: torch.Tensor, pre_process: bool = True +) -> tuple[torch.Tensor, PackedSeqParams]: + """ + Preprocess packed sequences + CP splits sequence into CP*2 chunks, and each GPU gets 2 chunks (GPU0 gets first and last chunks, GPU1 + gets second and second last chunks, and so on), this is for load balancing with causal masking. 
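+    Example (illustrative): with cp_size=2 a sequence is cut into 4 chunks
+    [c0, c1, c2, c3]; GPU0 keeps (c0, c3) and GPU1 keeps (c1, c2), so each
+    rank sees a similar mix of cheap early tokens and expensive late tokens.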
+ See https://github.com/NVIDIA/TransformerEngine/issues/1368 + """ + batch_size = input_ids.shape[0] + + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + tp_size = mpu.get_tensor_model_parallel_world_size() + cp_size = mpu.get_context_parallel_world_size() + cp_rank = mpu.get_context_parallel_rank() + align_size = tp_size * cp_size * 2 if cp_size > 1 else tp_size + + pad_size = (align_size - seqlens_in_batch % align_size) % align_size + seqlens_in_batch_padded = seqlens_in_batch + pad_size + + cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device) + cu_seqlens[1:] = torch.cumsum(seqlens_in_batch, dim=0) + cu_seqlens_padded = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device) + cu_seqlens_padded[1:] = torch.cumsum(seqlens_in_batch_padded, dim=0) + + # ---------------------------------------------------------------------------- + # Move the index information needed in the subsequent loop to the CPU at once, + # to avoid frequent .item() calls in the loop that cause D2H synchronization + # ---------------------------------------------------------------------------- + seqlens_in_batch_cpu: list[int] = seqlens_in_batch.tolist() # original valid lengths + seqlens_in_batch_padded_cpu: list[int] = seqlens_in_batch_padded.tolist() # lengths after padding + cu_seqlens_padded_cpu: list[int] = cu_seqlens_padded.tolist() # start positions (after padding) + + # Pure Python int calculation to avoid further synchronization + max_seqlen_in_batch = max(seqlens_in_batch_padded_cpu) + + shape = list(input_ids.shape[1:]) + shape[0] = sum(seqlens_in_batch_padded_cpu) // cp_size + if pre_process: + input_ids_rmpad = torch.zeros(shape, dtype=input_ids.dtype, device=input_ids.device) + for i in range(batch_size): + # Use Python int, so no GPU→CPU sync in the loop + if cp_size <= 1: + seqlen = seqlens_in_batch_cpu[i] + start_idx = cu_seqlens_padded_cpu[i] + input_ids_rmpad[start_idx : start_idx + seqlen] = input_ids[i, attention_mask[i]] + continue + + seqlen_padded_i = seqlens_in_batch_padded_cpu[i] + seqlen = seqlen_padded_i // cp_size + half_seqlen = seqlen // 2 + start_idx = cu_seqlens_padded_cpu[i] // cp_size + # split to 2 chunks + d = input_ids[i, attention_mask[i]] + input_ids_rmpad[start_idx : start_idx + half_seqlen] = d[ + half_seqlen * cp_rank : half_seqlen * (cp_rank + 1) + ] + + remain_start = seqlen_padded_i - half_seqlen * (cp_rank + 1) + remain_end = seqlen_padded_i - half_seqlen * cp_rank + remain_end = min(remain_end, d.shape[0]) + remain_len = remain_end - remain_start + if remain_len > 0: + input_ids_rmpad[start_idx + half_seqlen : start_idx + half_seqlen + remain_len] = d[ + remain_start:remain_end + ] + + packed_seq_params = PackedSeqParams( + qkv_format="thd", + cu_seqlens_q=cu_seqlens_padded, + max_seqlen_q=max_seqlen_in_batch, + cu_seqlens_kv=cu_seqlens_padded, + max_seqlen_kv=max_seqlen_in_batch, + cu_seqlens_q_padded=cu_seqlens_padded, + cu_seqlens_kv_padded=cu_seqlens_padded, + ) + + if pre_process: + return input_ids_rmpad.unsqueeze(0), packed_seq_params + else: + return input_ids, packed_seq_params + + +def postprocess_packed_seqs( + output: torch.Tensor, + packed_seq_params: PackedSeqParams, + attention_mask: torch.Tensor, + batch_size: int, + seq_len: int, + post_process: bool = True, +) -> torch.Tensor: + """ + Postprocess packed sequences + """ + if not post_process: + return output + + # ------------------------------------------------------------------------- + # Move the lengths and offsets needed for 
subsequent Python-level indexing to the CPU in advance, + # to avoid a large number of .item() calls in the loop + # ------------------------------------------------------------------------- + cu_padded_cpu: list[int] = packed_seq_params.cu_seqlens_q_padded.tolist() + seq_lens_cpu: list[int] = attention_mask.sum(dim=1, dtype=torch.int32).cpu().tolist() + + shape = [batch_size, seq_len] + list(output.shape[2:]) # 1,packed, dim -> batch_size, seq_len, dim + output_new = torch.zeros(shape, dtype=output.dtype, device=output.device) + + cp_size = mpu.get_context_parallel_world_size() + # all gather output across context parallel group + if cp_size > 1: + # output shape: [1, packed_len, hidden_dim] + # need to gather across cp group and concatenate in sequence dimension + output_list = [torch.empty_like(output) for _ in range(cp_size)] + torch.distributed.all_gather(output_list, output.detach(), group=mpu.get_context_parallel_group()) + output_list[mpu.get_context_parallel_rank()] = output + else: + output_list = [output] + for i in range(batch_size): + if cp_size <= 1: + s = seq_lens_cpu[i] + start_idx = cu_padded_cpu[i] + output_new[i, attention_mask[i]] = output[0][start_idx : start_idx + s] + continue + s_len_padded_chunk = (cu_padded_cpu[i + 1] - cu_padded_cpu[i]) // cp_size + half_seqlen = s_len_padded_chunk // 2 + s_len = seq_lens_cpu[i] + s_len_padded = s_len_padded_chunk * cp_size + tmp = torch.empty(s_len_padded, *output.shape[2:], device=output.device) + for j in range(cp_size): + o = output_list[j][0] + # split to 2 chunks + packed_start_idx = cu_padded_cpu[i] // cp_size + o0, o1 = ( + o[packed_start_idx : packed_start_idx + half_seqlen], + o[packed_start_idx + half_seqlen : packed_start_idx + s_len_padded_chunk], ) + tmp[j * half_seqlen : (j + 1) * half_seqlen] = o0 + tmp[s_len_padded - (j + 1) * half_seqlen : s_len_padded - j * half_seqlen] = o1 + output_new[i, attention_mask[i]] = tmp[:s_len] - return position_ids, mrope_position_deltas + return output_new diff --git a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/vision_model.py b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/vision_model.py new file mode 100644 index 0000000000..2cc1fb2e9a --- /dev/null +++ b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/vision_model.py @@ -0,0 +1,336 @@ +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from megatron.core import InferenceParams +from megatron.core.models.common.vision_module.vision_module import VisionModule +from megatron.core.models.vision.multimodal_projector import MultimodalProjector +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.transformer.enums import ModelType +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.utils import get_tensor_model_parallel_group_if_none +from torch import nn +from torch.nn import functional as F + +from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.transformer_block import Qwen3VLVisionTransformerBlock +from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.transformer_config import Qwen3VLTransformerConfig +from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.utils import ( + Qwen3VLVisionPatchEmbed, + Qwen3VLVisionPatchMerger, + Qwen3VLVisionRotaryEmbedding, +) + + +class Qwen3VLVisionModel(VisionModule): + """Qwen3 ViT vision model. 
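+
+    Rough data flow: Conv3d patch embedding -> interpolated position embedding
+    -> vision transformer blocks (with deepstack feature taps) -> patch merger
+    producing embeddings for the language model.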
+
+    Args:
+        transformer_config (TransformerConfig): Transformer config.
+        transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers.
+        patch_merger_spec (ModuleSpec): Specifies module to use for the patch merger.
+    """
+
+    def __init__(
+        self,
+        transformer_config: Qwen3VLTransformerConfig,
+        transformer_layer_spec: ModuleSpec,
+        patch_merger_spec: ModuleSpec,
+        pre_process: bool = True,
+        post_process: bool = True,
+        pg_collection: ProcessGroupCollection = None,
+    ) -> None:
+        assert post_process and pre_process, "pipeline parallelism is not supported: deepstack_merger_list requires both pre_process and post_process"
+        super().__init__(config=transformer_config)
+        self.spatial_merge_size = transformer_config.spatial_merge_size
+        self.patch_size = transformer_config.patch_size
+        self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
+        self.pg_collection = pg_collection
+        self.tp_group = self.pg_collection.tp
+
+        self.patch_embed = Qwen3VLVisionPatchEmbed(transformer_config)
+        self.pos_embed = nn.Embedding(transformer_config.num_position_embeddings, transformer_config.hidden_size)
+        self.num_grid_per_side = int(transformer_config.num_position_embeddings**0.5)
+
+        head_dim = transformer_config.hidden_size // transformer_config.num_attention_heads
+        self.rotary_pos_emb = Qwen3VLVisionRotaryEmbedding(head_dim // 2)
+
+        self.model_type = ModelType.encoder_or_decoder
+        self.pre_process = pre_process
+        self.post_process = post_process
+        # Transformer layers.
+        self.decoder = Qwen3VLVisionTransformerBlock(
+            config=transformer_config,
+            spec=transformer_layer_spec,
+            pre_process=self.pre_process,
+            post_process=self.post_process,
+            post_layer_norm=False,
+            patch_merger_spec=patch_merger_spec,
+            pg_collection=self.pg_collection,
+        )
+
+        self.merger = None
+        if self.post_process:
+            self.merger = Qwen3VLVisionPatchMerger(
+                transformer_config,
+                patch_merger_spec,
+                use_postshuffle_norm=False,
+                tp_group=self.tp_group,
+            )
+
+        self.input_tensor = None
+
+    def set_input_tensor(self, input_tensor: torch.Tensor) -> None:
+        """Sets input tensor to the model.
+
+        Args:
+            input_tensor (Tensor): The input tensor for the model.
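+
+        Note: only pre_process=True is supported; the vision encoder does not
+        participate in pipeline splits, so any other configuration raises
+        NotImplementedError.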
+ """ + if self.pre_process: # always True + self.input_tensor = input_tensor + else: + raise NotImplementedError() + + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + merge_size = self.spatial_merge_size + + max_hw = int(grid_thw[:, 1:].max().item()) + freq_table = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2) + device = freq_table.device + + total_tokens = int(torch.prod(grid_thw, dim=1).sum().item()) + pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device) + + offset = 0 + for num_frames, height, width in grid_thw: + merged_h, merged_w = height // merge_size, width // merge_size + + block_rows = torch.arange(merged_h, device=device) # block row indices + block_cols = torch.arange(merged_w, device=device) # block col indices + intra_row = torch.arange(merge_size, device=device) # intra-block row offsets + intra_col = torch.arange(merge_size, device=device) # intra-block col offsets + + # Compute full-resolution positions + row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None] + col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :] + + row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) + col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) + + coords = torch.stack((row_idx, col_idx), dim=-1) + + if num_frames > 1: + coords = coords.repeat(num_frames, 1) + + num_tokens = coords.shape[0] + pos_ids[offset : offset + num_tokens] = coords + offset += num_tokens + + embeddings = freq_table[pos_ids] # lookup rotary embeddings + embeddings = embeddings.flatten(1) + return embeddings + + def fast_pos_embed_interpolate(self, grid_thw): + grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2] + + idx_list = [[] for _ in range(4)] + weight_list = [[] for _ in range(4)] + + for t, h, w in zip(grid_ts, grid_hs, grid_ws): + h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) + w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) + + h_idxs_floor = h_idxs.int() + w_idxs_floor = w_idxs.int() + h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + + dh = h_idxs - h_idxs_floor + dw = w_idxs - w_idxs_floor + + base_h = h_idxs_floor * self.num_grid_per_side + base_h_ceil = h_idxs_ceil * self.num_grid_per_side + + indices = [ + (base_h[None].T + w_idxs_floor[None]).flatten(), + (base_h[None].T + w_idxs_ceil[None]).flatten(), + (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), + (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), + ] + + weights = [ + ((1 - dh)[None].T * (1 - dw)[None]).flatten(), + ((1 - dh)[None].T * dw[None]).flatten(), + (dh[None].T * (1 - dw)[None]).flatten(), + (dh[None].T * dw[None]).flatten(), + ] + + for i in range(4): + idx_list[i].extend(indices[i].tolist()) + weight_list[i].extend(weights[i].tolist()) + + idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=self.pos_embed.weight.device) + weight_tensor = torch.tensor( + weight_list, + dtype=self.pos_embed.weight.dtype, + device=self.pos_embed.weight.device, + ) + pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None] + patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] + + patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)]) + + patch_pos_embeds_permute = [] + merge_size = self.config.spatial_merge_size + for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, 
+
+    def fast_pos_embed_interpolate(self, grid_thw):
+        grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]
+
+        idx_list = [[] for _ in range(4)]
+        weight_list = [[] for _ in range(4)]
+
+        for t, h, w in zip(grid_ts, grid_hs, grid_ws):
+            h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h)
+            w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w)
+
+            h_idxs_floor = h_idxs.int()
+            w_idxs_floor = w_idxs.int()
+            h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
+            w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
+
+            dh = h_idxs - h_idxs_floor
+            dw = w_idxs - w_idxs_floor
+
+            base_h = h_idxs_floor * self.num_grid_per_side
+            base_h_ceil = h_idxs_ceil * self.num_grid_per_side
+
+            indices = [
+                (base_h[None].T + w_idxs_floor[None]).flatten(),
+                (base_h[None].T + w_idxs_ceil[None]).flatten(),
+                (base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
+                (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
+            ]
+
+            weights = [
+                ((1 - dh)[None].T * (1 - dw)[None]).flatten(),
+                ((1 - dh)[None].T * dw[None]).flatten(),
+                (dh[None].T * (1 - dw)[None]).flatten(),
+                (dh[None].T * dw[None]).flatten(),
+            ]
+
+            for i in range(4):
+                idx_list[i].extend(indices[i].tolist())
+                weight_list[i].extend(weights[i].tolist())
+
+        idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=self.pos_embed.weight.device)
+        weight_tensor = torch.tensor(
+            weight_list,
+            dtype=self.pos_embed.weight.dtype,
+            device=self.pos_embed.weight.device,
+        )
+        pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None]
+        patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]
+
+        patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)])
+
+        patch_pos_embeds_permute = []
+        merge_size = self.config.spatial_merge_size
+        for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
+            pos_embed = pos_embed.repeat(t, 1)
+            pos_embed = (
+                pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
+                .permute(0, 1, 3, 2, 4, 5)
+                .flatten(0, 4)
+            )
+            patch_pos_embeds_permute.append(pos_embed)
+        patch_pos_embeds = torch.cat(patch_pos_embeds_permute)
+        return patch_pos_embeds
+
+    def _get_max_vision_seq_length(self) -> int:
+        """Get the maximum sequence length for vision encoder CUDA graphs."""
+        if hasattr(self.config, 'max_vision_cuda_graph_seq_length') and self.config.max_vision_cuda_graph_seq_length:
+            return self.config.max_vision_cuda_graph_seq_length
+        # Default: calculate from num_position_embeddings
+        return self.config.num_position_embeddings // (self.config.spatial_merge_size ** 2)
+
+    def _uses_vision_cuda_graph(self) -> bool:
+        """Check if vision encoder CUDA graphs are enabled."""
+        return (
+            hasattr(self.config, 'cuda_graph_impl')
+            and self.config.cuda_graph_impl == "transformer_engine"
+            and self.training
+        )
+
+    def forward(
+        self,
+        hidden_states: Optional[torch.Tensor],
+        grid_thw: torch.Tensor,
+        inference_params: Optional[InferenceParams] = None,
+        extra_block_kwargs: dict = None,
+    ) -> tuple:
+        """Forward function of the Qwen3 Vision Model. This function passes the input tensors
+        through the embedding layer and then the transformer.
+
+        Args:
+            hidden_states (torch.Tensor): input image/video data of shape [n_tokens, n_dims]
+            grid_thw (torch.Tensor): grid sizes (t, h, w) of each image/frame
+            inference_params (InferenceParams): unused; must be None
+            extra_block_kwargs (dict): extra keyword arguments forwarded to the transformer block
+
+        Returns:
+            tuple: hidden states of shape [n_tokens, h] after the patch merger, plus the
+                list of deepstack features.
+        """
+        assert grid_thw is not None
+        assert self.input_tensor is None
+        assert inference_params is None
+
+        hidden_states = self.patch_embed(hidden_states)
+
+        pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
+        hidden_states = hidden_states + pos_embeds
+
+        seq_len, _ = hidden_states.size()
+
+        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, 1, 1, -1).repeat(1, 1, 1, 2)
+
+        # Check if we need to pad for CUDA graphs
+        use_cuda_graph_padding = self._uses_vision_cuda_graph()
+        original_seq_len = seq_len
+
+        if use_cuda_graph_padding:
+            max_seq_len = self._get_max_vision_seq_length()
+            if seq_len > max_seq_len:
+                raise ValueError(
+                    f"Vision input sequence length ({seq_len}) exceeds max_vision_cuda_graph_seq_length ({max_seq_len}). "
+                    f"Increase max_vision_cuda_graph_seq_length in config or disable vision CUDA graphs."
+                )
+
+            if seq_len < max_seq_len:
+                # Pad hidden_states: [seq_len, hidden_size] -> [max_seq_len, hidden_size]
+                pad_len = max_seq_len - seq_len
+                hidden_states = F.pad(hidden_states, (0, 0, 0, pad_len), value=0.0)
+
+                # Pad rotary_pos_emb: [seq_len, 1, 1, head_dim*2] -> [max_seq_len, 1, 1, head_dim*2]
+                rotary_pos_emb = F.pad(rotary_pos_emb, (0, 0, 0, 0, 0, 0, 0, pad_len), value=0.0)
+
+                seq_len = max_seq_len
+
+        hidden_states = hidden_states[:, None]
+
+        # When using CUDA graphs, we don't pass packed_seq_params (non-tensor).
+        # Instead, use full attention, which is fine since we pad to a fixed size.
+        if use_cuda_graph_padding:
+            packed_seq_params = None
+            # Build an attention mask for the padded input.
+            # The vision encoder uses full (non-causal) attention, but we still
+            # need to mask out the padding positions.
+            if original_seq_len < seq_len:
+                # Create attention mask: [1, 1, seq_len, seq_len]
+                # Mask out attention to/from padding positions
+                attention_mask = torch.ones(
+                    (1, 1, seq_len, seq_len),
+                    dtype=hidden_states.dtype,
+                    device=hidden_states.device
+                )
+                # Mask out padding columns (keys)
+                attention_mask[:, :, :, original_seq_len:] = 0
+                # Mask out padding rows (queries) - these will produce garbage anyway
+                attention_mask[:, :, original_seq_len:, :] = 0
+                # Convert to an additive mask (1 -> 0, 0 -> large negative) for the softmax
+                attention_mask = (1.0 - attention_mask) * torch.finfo(hidden_states.dtype).min
+            else:
+                attention_mask = None
+        else:
+            packed_seq_params = self.build_packed_seq_params(grid_thw)
+            attention_mask = None
+
+        hidden_states, deepstack_feature_lists = self.decoder(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            inference_params=inference_params,
+            rotary_pos_emb=rotary_pos_emb,
+            packed_seq_params=packed_seq_params,
+            **(extra_block_kwargs or {}),
+        )
+
+        # Remove padding if we added it
+        if use_cuda_graph_padding and original_seq_len < seq_len:
+            hidden_states = hidden_states[:original_seq_len]
+            # Unpad deepstack features - they go through a merger that reduces by spatial_merge_size^2,
+            # so their length is seq_len // (spatial_merge_size^2)
+            original_merged_seq_len = original_seq_len // (self.spatial_merge_size ** 2)
+            deepstack_feature_lists = [feat[:original_merged_seq_len] for feat in deepstack_feature_lists]
+
+        hidden_states = self.merger(hidden_states)
+
+        # Encodes images into continuous embeddings that can be forwarded to the language model.
+        split_sizes = (grid_thw.prod(-1) // self.spatial_merge_size**2).tolist()
+        hidden_states = torch.split(hidden_states, split_sizes)
+        hidden_states = torch.cat(hidden_states, dim=0)
+        return hidden_states, deepstack_feature_lists
+
+    def build_packed_seq_params(
+        self,
+        grid_thw: Optional[torch.Tensor],
+    ) -> PackedSeqParams:
+        # NOTE: each frame is a sequence (rather than each grid)
+        seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])
+        cu_seqlens = seqlens.cumsum(dim=0)
+        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).int()
+
+        max_seqlen_q = seqlens.max()
+        return PackedSeqParams(
+            cu_seqlens_q=cu_seqlens,
+            cu_seqlens_kv=cu_seqlens,
+            qkv_format="thd",
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_kv=max_seqlen_q,
+        )
diff --git a/src/megatron/bridge/models/qwen_vl/qwen3_vl_bridge.py b/src/megatron/bridge/models/qwen_vl/qwen3_vl_bridge.py
index 35d92c78fd..559e8cc5c1 100644
--- a/src/megatron/bridge/models/qwen_vl/qwen3_vl_bridge.py
+++ b/src/megatron/bridge/models/qwen_vl/qwen3_vl_bridge.py
@@ -27,6 +27,7 @@
     GatedMLPMapping,
     QKVMapping,
     ReplicatedMapping,
+    ConcatenatedQKVMapping,
 )
 from megatron.bridge.models.hf_pretrained.vlm import PreTrainedVLM
 from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.model import Qwen3VLModel
@@ -151,6 +152,35 @@ def mapping_registry(self) -> MegatronMappingRegistry:
             # QK layernorm weights (Qwen3 specific)
             "language_model.decoder.layers.*.self_attention.q_layernorm.weight": "model.language_model.layers.*.self_attn.q_norm.weight",
             "language_model.decoder.layers.*.self_attention.k_layernorm.weight": "model.language_model.layers.*.self_attn.k_norm.weight",
+            # vision module attn
+            "vision_model.decoder.layers.*.self_attention.linear_proj.weight": "model.visual.blocks.*.attn.proj.weight",
+            "vision_model.decoder.layers.*.self_attention.linear_proj.bias": "model.visual.blocks.*.attn.proj.bias",
+            # vision module mlp
+            "vision_model.decoder.layers.*.mlp.linear_fc1.weight": "model.visual.blocks.*.mlp.linear_fc1.weight",
+            "vision_model.decoder.layers.*.mlp.linear_fc1.bias": "model.visual.blocks.*.mlp.linear_fc1.bias",
+            "vision_model.decoder.layers.*.mlp.linear_fc2.weight": "model.visual.blocks.*.mlp.linear_fc2.weight",
+            "vision_model.decoder.layers.*.mlp.linear_fc2.bias": "model.visual.blocks.*.mlp.linear_fc2.bias",
+            # vision module norm
+            "vision_model.decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "model.visual.blocks.*.norm1.weight",
+            "vision_model.decoder.layers.*.self_attention.linear_qkv.layer_norm_bias": "model.visual.blocks.*.norm1.bias",
+            "vision_model.decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "model.visual.blocks.*.norm2.weight",
+            "vision_model.decoder.layers.*.mlp.linear_fc1.layer_norm_bias": "model.visual.blocks.*.norm2.bias",
+            # vision module deepstack merger
+            "vision_model.decoder.deepstack_merger_list.*.patch_norm.weight": "model.visual.deepstack_merger_list.*.norm.weight",
+            "vision_model.decoder.deepstack_merger_list.*.patch_norm.bias": "model.visual.deepstack_merger_list.*.norm.bias",
+            "vision_model.decoder.deepstack_merger_list.*.linear_fc1.weight": "model.visual.deepstack_merger_list.*.linear_fc1.weight",
+            "vision_model.decoder.deepstack_merger_list.*.linear_fc1.bias": "model.visual.deepstack_merger_list.*.linear_fc1.bias",
+            "vision_model.decoder.deepstack_merger_list.*.linear_fc2.weight": "model.visual.deepstack_merger_list.*.linear_fc2.weight",
+            "vision_model.decoder.deepstack_merger_list.*.linear_fc2.bias": "model.visual.deepstack_merger_list.*.linear_fc2.bias",
+            # vision module merger
+            "vision_model.merger.patch_norm.**": "model.visual.merger.norm.**",
+            "vision_model.merger.linear_fc1.weight": "model.visual.merger.linear_fc1.weight",
+            "vision_model.merger.linear_fc1.bias": "model.visual.merger.linear_fc1.bias",
+            "vision_model.merger.linear_fc2.weight": "model.visual.merger.linear_fc2.weight",
+            "vision_model.merger.linear_fc2.bias": "model.visual.merger.linear_fc2.bias",
+            # The patch_embed weights are convolutional, so they are handled by ReplicatedMapping below:
+            # "vision_model.patch_embed.proj.**": "model.visual.patch_embed.proj.**",
+            # "vision_model.pos_embed.weight": "model.visual.pos_embed.weight",
         }
 
         mapping_list = []
@@ -162,11 +192,22 @@ def mapping_registry(self) -> MegatronMappingRegistry:
         # Add special mappings that require parameter transformation
         mapping_list.extend(
             [
-                # Vision model weights are replicated directly
-                # This handles all vision encoder layers, patch embeddings, mergers, etc.
+                # QKV mapping for vision model
+                ConcatenatedQKVMapping(
+                    megatron_param="vision_model.decoder.layers.*.self_attention.linear_qkv.weight",
+                    hf_param="model.visual.blocks.*.attn.qkv.weight",
+                ),
+                ConcatenatedQKVMapping(
+                    megatron_param="vision_model.decoder.layers.*.self_attention.linear_qkv.bias",
+                    hf_param="model.visual.blocks.*.attn.qkv.bias",
+                ),
                 ReplicatedMapping(
-                    megatron_param="vision_model.**",
-                    hf_param="model.visual.**",
+                    megatron_param="vision_model.patch_embed.proj.**",
+                    hf_param="model.visual.patch_embed.proj.**",
+                ),
+                ReplicatedMapping(
+                    megatron_param="vision_model.pos_embed.**",
+                    hf_param="model.visual.pos_embed.**",
                 ),
                 # QKV mapping: Combine separate Q, K, V matrices into single QKV matrix
                 QKVMapping(
diff --git a/src/megatron/bridge/models/qwen_vl/qwen3_vl_provider.py b/src/megatron/bridge/models/qwen_vl/qwen3_vl_provider.py
index 9cff84c314..485bf89fb7 100644
--- a/src/megatron/bridge/models/qwen_vl/qwen3_vl_provider.py
+++ b/src/megatron/bridge/models/qwen_vl/qwen3_vl_provider.py
@@ -25,12 +25,32 @@
 from megatron.core.models.gpt import GPTModel as MCoreGPTModel
 from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
+from megatron.core.extensions.transformer_engine import (
+    TEColumnParallelLinear,
+    TENorm,
+    TERowParallelLinear,
+)
 from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLTextConfig, Qwen3VLVisionConfig
 from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import Qwen3VLMoeTextConfig
 
 from megatron.bridge.models import Qwen3ModelProvider, Qwen3MoEModelProvider
 from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.model import Qwen3VLModel
+from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.transformer_config import get_vision_model_config
+from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.utils import PatchMergerSubmodules
+from megatron.core.models.vision.vit_layer_specs import (
+    get_vit_layer_with_transformer_engine_spec,
+)
+from megatron.core.transformer.cuda_graphs import CudaGraphScope
+from copy import deepcopy
+
+
+def _convert_cuda_graph_scope_to_enum(scope_list: List[str]) -> List[CudaGraphScope]:
+    """Convert string list to CudaGraphScope enum list."""
+    if not scope_list:
+        return []
+    return [CudaGraphScope[scope] for scope in scope_list]
+
 
 @dataclass
 class Qwen3VLModelProvider(Qwen3ModelProvider):
@@ -106,14 +126,45 @@ class Qwen3VLModelProvider(Qwen3ModelProvider):
 
     qk_layernorm: bool = True
 
+    bias_activation_fusion: bool = True  # Fuse swiglu bias and activation
+
+    # Vision encoder CUDA graph settings
+    # Set to "transformer_engine" to enable TE CUDA graph for vision encoder
+    vision_cuda_graph_impl: str = "none"
+    # CUDA graph scope for vision encoder (e.g., ["attn"] for attention only)
+    vision_cuda_graph_scope: List[str] = field(default_factory=list)
+    # Maximum sequence length for vision encoder CUDA graphs (must accommodate largest input)
+    # If None, calculated from num_position_embeddings / spatial_merge_size^2
+    max_vision_cuda_graph_seq_length: Optional[int] = None
+
     def provide(self, pre_process=None, post_process=None, vp_stage=None):
         """
         Provide a Qwen3VL model instance with vision and language components.
         """
         language_transformer_config = self
 
+        # Convert language model's cuda_graph_scope from strings to enums if needed
+        if hasattr(language_transformer_config, 'cuda_graph_scope') and language_transformer_config.cuda_graph_scope:
+            if isinstance(language_transformer_config.cuda_graph_scope[0], str):
+                language_transformer_config.cuda_graph_scope = _convert_cuda_graph_scope_to_enum(
+                    language_transformer_config.cuda_graph_scope
+                )
+
         hf_vision_config = self.vision_config
+        vision_transformer_config = get_vision_model_config(deepcopy(language_transformer_config), hf_vision_config)
+        vision_transformer_config.pipeline_model_parallel_size = 1
+        vision_transformer_config.first_pipeline_num_layers = None
+
+        # Apply vision encoder CUDA graph settings
+        vision_transformer_config.cuda_graph_impl = self.vision_cuda_graph_impl
+        # Convert string scope list to CudaGraphScope enums
+        vision_transformer_config.cuda_graph_scope = _convert_cuda_graph_scope_to_enum(
+            self.vision_cuda_graph_scope
+        ) if self.vision_cuda_graph_scope else []
+        # Set max sequence length for vision CUDA graphs
+        vision_transformer_config.max_vision_cuda_graph_seq_length = self.max_vision_cuda_graph_seq_length
+
         # Spec for the Qwen3VLTransformerLayer
         language_transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
             num_experts=None,
@@ -121,13 +172,22 @@
             qk_layernorm=self.qk_layernorm,
             fp8=False,
         )
+        vision_transformer_layer_spec = get_vit_layer_with_transformer_engine_spec()
+        vision_patch_merger_spec = PatchMergerSubmodules(
+            patch_norm=TENorm,
+            linear_fc1=TEColumnParallelLinear,
+            linear_fc2=TERowParallelLinear,
+        )
 
         model = Qwen3VLModel(
             language_transformer_config=language_transformer_config,
             language_transformer_layer_spec=language_transformer_layer_spec,
             vision_transformer_config=hf_vision_config,
+            vision_transformer_layer_spec=vision_transformer_layer_spec,
+            vision_patch_merger_spec=vision_patch_merger_spec,
             pre_process=pre_process,
             post_process=post_process,
+            pg_collection=self._pg_collection,
         )
 
         # Apply freeze options if any are enabled for fine-tuning
@@ -243,6 +303,15 @@ class Qwen3VLMoEModelProvider(Qwen3MoEModelProvider):
     freeze_vision_projection: bool = False
 
     language_max_sequence_length: int = 2048
+
+    # Vision encoder CUDA graph settings
+    # Set to "transformer_engine" to enable TE CUDA graph for vision encoder
+    vision_cuda_graph_impl: str = "none"
+    # CUDA graph scope for vision encoder (e.g., ["attn"] for attention only)
+    vision_cuda_graph_scope: List[str] = field(default_factory=list)
+    # Maximum sequence length for vision encoder CUDA graphs (must accommodate largest input)
+    # If None, calculated from num_position_embeddings / spatial_merge_size^2
+    max_vision_cuda_graph_seq_length: Optional[int] = None
+
     # QK layernorm is already True in Qwen3MoEModelProvider, no need
to redefine # These are typically set in the base class but documented here for clarity @@ -255,6 +324,8 @@ class Qwen3VLMoEModelProvider(Qwen3MoEModelProvider): distribute_saved_activations: bool = False # Don't distribute saved activations cp_comm_type: str = "p2p" # Point-to-point communication for context parallel + use_hf_vision_model: bool = False + def finalize(self) -> None: if self.tensor_model_parallel_size > 1: self.sequence_parallel = True @@ -267,9 +338,29 @@ def provide(self, pre_process=None, post_process=None, vp_stage=None): """ language_transformer_config = self + # Convert language model's cuda_graph_scope from strings to enums if needed + if hasattr(language_transformer_config, 'cuda_graph_scope') and language_transformer_config.cuda_graph_scope: + if isinstance(language_transformer_config.cuda_graph_scope[0], str): + language_transformer_config.cuda_graph_scope = _convert_cuda_graph_scope_to_enum( + language_transformer_config.cuda_graph_scope + ) + # Create vision transformer config - placeholder for future use # vision_transformer_config = deepcopy(self) - hf_config = self.vision_config + hf_vision_config = self.vision_config + + vision_transformer_config = get_vision_model_config(deepcopy(language_transformer_config), hf_vision_config) + vision_transformer_config.pipeline_model_parallel_size = 1 + vision_transformer_config.first_pipeline_num_layers = None + + # Apply vision encoder CUDA graph settings + vision_transformer_config.cuda_graph_impl = self.vision_cuda_graph_impl + # Convert string scope list to CudaGraphScope enums + vision_transformer_config.cuda_graph_scope = _convert_cuda_graph_scope_to_enum( + self.vision_cuda_graph_scope + ) if self.vision_cuda_graph_scope else [] + # Set max sequence length for vision CUDA graphs + vision_transformer_config.max_vision_cuda_graph_seq_length = self.max_vision_cuda_graph_seq_length language_transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( num_experts=self.num_moe_experts, @@ -278,13 +369,23 @@ def provide(self, pre_process=None, post_process=None, vp_stage=None): fp8=False, ) + vision_transformer_layer_spec = get_vit_layer_with_transformer_engine_spec() + vision_patch_merger_spec = PatchMergerSubmodules( + patch_norm=TENorm, + linear_fc1=TEColumnParallelLinear, + linear_fc2=TERowParallelLinear, + ) + # reuse Qwen3VLModel for MoE model but replace the language model with MoE language model model = Qwen3VLModel( language_transformer_config=language_transformer_config, language_transformer_layer_spec=language_transformer_layer_spec, - vision_transformer_config=hf_config, + vision_transformer_config=hf_vision_config, + vision_transformer_layer_spec=vision_transformer_layer_spec, + vision_patch_merger_spec=vision_patch_merger_spec, pre_process=pre_process, post_process=post_process, + pg_collection=self._pg_collection, ) # Apply freeze options if any are enabled for fine-tuning diff --git a/src/megatron/bridge/training/qwen3vl_step.py b/src/megatron/bridge/training/qwen3vl_step.py new file mode 100644 index 0000000000..c233d11c02 --- /dev/null +++ b/src/megatron/bridge/training/qwen3vl_step.py @@ -0,0 +1,369 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+import logging
+from functools import partial
+from typing import Any, Iterable
+
+import torch
+from megatron.core.models.gpt import GPTModel
+from megatron.core.pipeline_parallel.utils import is_pp_first_stage, is_pp_last_stage
+from megatron.core.utils import get_batch_on_this_cp_rank, get_model_config, get_thd_batch_on_this_cp_rank
+
+from megatron.bridge.training.config import ConfigContainer
+from megatron.bridge.training.losses import (
+    create_masked_next_token_loss_function as _create_loss_function,
+)
+from megatron.core.packed_seq_params import PackedSeqParams
+from megatron.bridge.training.state import GlobalState
+from megatron.bridge.training.utils.padding_utils import (
+    pad_or_truncate_2d_to_len,
+    pad_or_truncate_attn_to_len,
+    pad_or_truncate_pos_to_len,
+)
+from megatron.bridge.training.utils.pg_utils import get_pg_collection
+
+
+logger = logging.getLogger(__name__)
+
+
+def get_batch_from_iterator(
+    data_iterator: Iterable,
+    use_mtp: bool = False,
+    skip_getting_attention_mask_from_dataset: bool = True,
+    *,
+    is_first_pp_stage: bool,
+    is_last_pp_stage: bool,
+) -> dict[str, Any]:
+    """Get a batch of data from the iterator.
+
+    Args:
+        data_iterator: The data iterator to get the batch from.
+        use_mtp: Whether Multi-Token Prediction layers are enabled.
+        skip_getting_attention_mask_from_dataset: If set, the dataset will pass a None attention mask.
+        is_first_pp_stage: Whether this rank runs the first pipeline stage.
+        is_last_pp_stage: Whether this rank runs the last pipeline stage.
+
+    Returns:
+        dict[str, torch.Tensor]: A dictionary containing the batch data.
+    """
+    batch = next(data_iterator)
+
+    required_device_keys = set()
+    required_host_keys = set()
+
+    if not skip_getting_attention_mask_from_dataset:
+        required_device_keys.add("attention_mask")
+
+    # Instead of raw tensors, expect a single 'visual_inputs' object in batch
+    required_device_keys.add("visual_inputs")
+
+    if "cu_seqlens" in batch:
+        required_device_keys.add("cu_seqlens")
+        required_host_keys.add("cu_seqlens_argmin")
+        required_host_keys.add("max_seqlen")
+
+    required_device_keys.update(("tokens", "input_ids", "position_ids"))
+    if is_last_pp_stage:
+        required_device_keys.update(("labels", "loss_mask"))
+
+    _batch_required_keys = {}
+    for key, val in batch.items():
+        if key in required_device_keys:
+            if key == "visual_inputs":
+                if val is None:
+                    _batch_required_keys[key] = None
+                else:
+                    _batch_required_keys[key] = val
+                    # Move all tensors contained in the visual inputs to CUDA
+                    for k, v in val.__dict__.items():
+                        _batch_required_keys[key].__dict__[k] = v.cuda(non_blocking=True) if v is not None else None
+            else:
+                _batch_required_keys[key] = val.cuda(non_blocking=True) if val is not None else None
+        elif key in required_host_keys:
+            _batch_required_keys[key] = val.cpu() if val is not None else None
+        else:
+            _batch_required_keys[key] = None
+
+    return _batch_required_keys
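get_batch_from_iterator only assumes that the batch's 'visual_inputs' entry is an object whose tensor attributes can be moved to CUDA and later flattened into model kwargs via normalized_for_model(). A hypothetical stand-in container (names invented for illustration; the real provider lives in the dataset code):

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class VisualInputs:
    pixel_values: Optional[torch.Tensor] = None
    image_grid_thw: Optional[torch.Tensor] = None

    def normalized_for_model(self) -> dict:
        # Drop absent modalities so the step only forwards present tensors.
        return {k: v for k, v in self.__dict__.items() if v is not None}


assert list(VisualInputs(pixel_values=torch.randn(4, 3)).normalized_for_model()) == ["pixel_values"]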
+
+
+def get_batch(
+    data_iterator: Iterable,
+    cfg: ConfigContainer,
+    use_mtp: bool = False,
+    *,
+    is_first_pp_stage: bool,
+    is_last_pp_stage: bool,
+) -> tuple[
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    Any,
+]:
+    """Generate a batch.
+
+    Args:
+        data_iterator: Input data iterator
+        cfg: Configuration container
+        use_mtp: Whether Multi-Token Prediction layers are enabled
+        is_first_pp_stage: Whether the current stage is the first stage
+        is_last_pp_stage: Whether the current stage is the last stage
+
+    Returns:
+        A tuple of (tokens, labels, loss_mask, attention_mask, position_ids,
+        multi_modal_inputs), where multi_modal_inputs is a dict of visual tensors.
+    """
+    # All PP stages load from iterator to get input_ids and visual grid info
+    # This allows each stage to compute MRoPE position_ids locally without broadcasting
+    batch = get_batch_from_iterator(
+        data_iterator,
+        use_mtp,
+        getattr(cfg.dataset, "skip_getting_attention_mask_from_dataset", True),
+        is_first_pp_stage=is_first_pp_stage,
+        is_last_pp_stage=is_last_pp_stage,
+    )
+
+    if "visual_inputs" in batch:
+        # Convert visual_inputs to multi_modal_inputs, a dict containing "pixel_values" and "image_grid_thw"
+        # TODO(jinliangl): add video support
+        multi_modal_inputs = batch.get("visual_inputs").normalized_for_model()
+    else:
+        multi_modal_inputs = {}
+
+    # Return the raw batch without any padding or CP slicing
+    return (
+        batch.get("tokens") if batch.get("tokens") is not None else batch.get("input_ids"),
+        batch.get("labels"),
+        batch.get("loss_mask"),
+        batch.get("attention_mask"),
+        batch.get("position_ids"),
+        multi_modal_inputs,
+    )
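get_batch deliberately returns the raw batch; the padding policy lives in pack_or_pad_batch_sequences below. Its alignment rule reduces to this small helper (a sketch of the same arithmetic, not the function itself):

import math

def alignment(tp_size: int, cp_size: int, use_fp8_padding: bool) -> int:
    # CP slices each sequence into 2*cp_size chunks; TP needs tp_size divisibility.
    base = tp_size * cp_size * 2 if cp_size > 1 else tp_size
    # FP8 kernels prefer lengths that are multiples of 16.
    return math.lcm(base, 16) if use_fp8_padding else base

assert alignment(2, 1, False) == 2
assert alignment(2, 4, False) == 16
assert alignment(3, 1, True) == 48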
+
+
+def pack_or_pad_batch_sequences(
+    tokens: torch.Tensor,
+    labels: torch.Tensor,
+    loss_mask: torch.Tensor,
+    attention_mask: torch.Tensor,
+    position_ids: torch.Tensor,
+    this_pg_collection,
+    use_fp8_padding: bool = False,
+    data_format: str = "bshd",
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, PackedSeqParams]:
+    """
+    Pad or truncate the batch sequences to the target length, and build packed sequences.
+    If data_format is "bshd", return bshd tokens to stay compatible with the Qwen3VL model.
+    Otherwise, return thd tokens and packed sequences.
+    """
+
+    batch_size, cur_len = tokens.shape
+    device = tokens.device
+
+    tp_size = this_pg_collection.tp.size()
+    cp_size = this_pg_collection.cp.size()
+    divisible_by = tp_size * cp_size * 2 if cp_size > 1 else tp_size
+    divisible_by = math.lcm(divisible_by, 16) if use_fp8_padding else divisible_by
+
+    if data_format == "thd":
+        # build thd sequences with tiny padding
+        seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+        seqlens_in_batch_padded = seqlens_in_batch + (divisible_by - seqlens_in_batch % divisible_by) % divisible_by
+        cu_seqlens_padded = torch.zeros(batch_size + 1, dtype=torch.int32, device=tokens.device)
+        cu_seqlens_padded[1:] = torch.cumsum(seqlens_in_batch_padded, dim=0)
+        max_seqlen_in_batch_padded = seqlens_in_batch_padded.max().item()
+
+        seqlens_in_batch_cpu = seqlens_in_batch.tolist()
+        cu_seqlens_padded_cpu = cu_seqlens_padded.tolist()
+        total_len = cu_seqlens_padded_cpu[-1]
+
+        # Concatenate sequences (removing padding)
+        packed_tokens = torch.zeros(1, total_len, dtype=tokens.dtype, device=device)
+        packed_labels = torch.zeros(1, total_len, dtype=labels.dtype, device=device)
+        packed_loss_mask = torch.zeros(1, total_len, dtype=loss_mask.dtype, device=device)
+        packed_position_ids = torch.zeros(1, total_len, dtype=position_ids.dtype, device=device)
+
+        for i, seqlen in enumerate(seqlens_in_batch_cpu):
+            start_idx = cu_seqlens_padded_cpu[i]
+            packed_tokens[0, start_idx : start_idx + seqlen] = tokens[i, :seqlen]
+            packed_labels[0, start_idx : start_idx + seqlen] = labels[i, :seqlen]
+            packed_loss_mask[0, start_idx : start_idx + seqlen] = loss_mask[i, :seqlen]
+            packed_position_ids[0, start_idx : start_idx + seqlen] = position_ids[i, :seqlen]
+
+        tokens = packed_tokens
+        labels = packed_labels
+        loss_mask = packed_loss_mask
+        position_ids = packed_position_ids
+        # attention mask is not used with packed sequences (handled by cu_seqlens)
+        attention_mask = None
+    elif data_format == "bshd":
+        # build bshd sequences with tiny padding to be compatible with the qwen3vl model
+        target_len = math.ceil(cur_len / divisible_by) * divisible_by
+        tokens = pad_or_truncate_2d_to_len(tokens, target_len=target_len, max_cap=target_len, pad_value=0)
+        labels = pad_or_truncate_2d_to_len(labels, target_len=target_len, max_cap=target_len, pad_value=-100)
+        loss_mask = pad_or_truncate_2d_to_len(loss_mask, target_len=target_len, max_cap=target_len, pad_value=0)
+        attention_mask = pad_or_truncate_attn_to_len(attention_mask, target_len=target_len, max_cap=target_len)
+        position_ids = pad_or_truncate_pos_to_len(position_ids, target_len=target_len, max_cap=target_len)
+
+        seqlens_in_batch = torch.ones(batch_size, dtype=torch.int32, device=tokens.device) * target_len
+        seqlens_in_batch_padded = torch.ones(batch_size, dtype=torch.int32, device=tokens.device) * target_len
+        cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device=tokens.device)
+        cu_seqlens[1:] = torch.cumsum(seqlens_in_batch, dim=0)
+        cu_seqlens_padded = torch.zeros(batch_size + 1, dtype=torch.int32, device=tokens.device)
+        cu_seqlens_padded[1:] = torch.cumsum(seqlens_in_batch_padded, dim=0)
+        max_seqlen_in_batch = seqlens_in_batch.max().item()
+        max_seqlen_in_batch_padded = seqlens_in_batch_padded.max().item()
+    else:
+        raise ValueError(f"Invalid data format: {data_format}")
+
+    packed_seq_params = PackedSeqParams(
+        qkv_format="thd",
+        cu_seqlens_q=cu_seqlens_padded,
+        max_seqlen_q=max_seqlen_in_batch_padded,
+        cu_seqlens_kv=cu_seqlens_padded,
+        max_seqlen_kv=max_seqlen_in_batch_padded,
+        cu_seqlens_q_padded=cu_seqlens_padded,
+        cu_seqlens_kv_padded=cu_seqlens_padded,
+    )
+
+    return tokens, labels, loss_mask, attention_mask, position_ids, packed_seq_params
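To make the thd branch concrete: two sequences of lengths 3 and 5, padded to an alignment of 4, pack into one row as follows (toy values):

import torch

tokens = torch.tensor([[1, 2, 3, 0, 0], [4, 5, 6, 7, 8]])
seqlens = torch.tensor([3, 5])
seqlens_padded = torch.tensor([4, 8])
cu = torch.zeros(3, dtype=torch.int32)
cu[1:] = seqlens_padded.cumsum(0)          # [0, 4, 12]
packed = torch.zeros(1, int(cu[-1]), dtype=tokens.dtype)
for i, (s, start) in enumerate(zip(seqlens.tolist(), cu[:-1].tolist())):
    packed[0, start : start + s] = tokens[i, :s]
# packed -> [[1, 2, 3, 0, 4, 5, 6, 7, 8, 0, 0, 0]]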
+
+
+def forward_step(
+    state: GlobalState,
+    data_iterator: Iterable,
+    model: GPTModel,
+    return_schedule_plan: bool = False,
+) -> tuple[torch.Tensor, partial]:
+    """Forward training step.
+
+    Args:
+        state: Global state for the run
+        data_iterator: Input data iterator
+        model: The GPT Model
+        return_schedule_plan (bool): Whether to return the schedule plan instead of the output tensor
+
+    Returns:
+        tuple containing the output tensor and the loss function
+    """
+    timers = state.timers
+    straggler_timer = state.straggler_timer
+
+    this_pg_collection = get_pg_collection(model)
+    is_first = is_pp_first_stage(this_pg_collection.pp)
+    is_last = is_pp_last_stage(this_pg_collection.pp)
+
+    is_qwen3vl = "Qwen3-VL" in getattr(state.cfg.model, "hf_model_id", "")
+    assert is_qwen3vl, "Only Qwen3-VL model is supported"
+
+    config = get_model_config(model)
+    use_mtp = (getattr(config, "mtp_num_layers", None) or 0) > 0
+
+    timers("batch-generator", log_level=2).start()
+    with straggler_timer(bdata=True):
+        (
+            tokens,
+            labels,
+            loss_mask,
+            attention_mask,
+            position_ids,
+            multi_modal_inputs,
+        ) = get_batch(data_iterator, state.cfg, use_mtp, is_first_pp_stage=is_first, is_last_pp_stage=is_last)
+    timers("batch-generator").stop()
+
+    # To stay compatible with qwen3vl, sequence padding and packing are done here in forward_step:
+    # the Qwen3VL model needs the original input and performs the CP and SP splits in model.forward.
+    pack_sequences_in_batch = getattr(state.cfg.dataset, "pack_sequences_in_batch", False)
+    if pack_sequences_in_batch:
+        data_format = "thd"
+        if is_qwen3vl:
+            data_format = "bshd"
+    else:
+        data_format = "bshd"
+
+    tokens, labels, loss_mask, attention_mask, position_ids, packed_seq_params = pack_or_pad_batch_sequences(
+        tokens,
+        labels,
+        loss_mask,
+        attention_mask,
+        position_ids,
+        this_pg_collection,
+        use_fp8_padding=True,
+        data_format=data_format,
+    )
+    forward_args = {
+        "input_ids": tokens,
+        "labels": labels,
+        "loss_mask": loss_mask,
+        "attention_mask": attention_mask,
+        "position_ids": position_ids,
+    }
+
+    if data_format == "thd":
+        forward_args, packed_seq_params = get_thd_batch_on_this_cp_rank(
+            forward_args,
+            packed_seq_params.cu_seqlens_q,
+            packed_seq_params.cu_seqlens_q_padded,
+            packed_seq_params.max_seqlen_q,
+            cp_size=this_pg_collection.cp.size(),
+            cp_rank=this_pg_collection.cp.rank(),
+        )
+        forward_args["packed_seq_params"] = packed_seq_params
+    elif data_format == "bshd":
+        original_tokens = tokens.clone()
+        forward_args = get_batch_on_this_cp_rank(forward_args, cp_group=this_pg_collection.cp)
+        forward_args["packed_seq_params"] = None
+        if is_qwen3vl:
+            forward_args["input_ids"] = original_tokens
+            # position_ids are computed inside model.forward
+            forward_args["position_ids"] = None
+            if pack_sequences_in_batch:
+                forward_args["labels"] = forward_args["labels"].reshape(1, -1)
+                attention_mask = torch.ones(original_tokens.shape[0], original_tokens.shape[1], dtype=torch.bool, device=original_tokens.device)
+                forward_args["attention_mask"] = attention_mask
+                # qwen3vl needs the original input_ids and position_ids;
+                # the split attention mask is used to calculate the loss
+                forward_args["packed_seq_params"] = packed_seq_params
+    else:
+        raise ValueError(f"Invalid data format: {data_format}")
+
+    # use the CP-split loss mask to calculate the loss
+    loss_mask = forward_args["loss_mask"]
+    # Following the design of veRL, we put
the multi-modal inputs in the forward args + if "pixel_values" in multi_modal_inputs: + forward_args["pixel_values"] = multi_modal_inputs["pixel_values"] + if "image_grid_thw" in multi_modal_inputs: + forward_args["image_grid_thw"] = multi_modal_inputs["image_grid_thw"] + if "pixel_values_videos" in multi_modal_inputs: + forward_args["pixel_values_videos"] = multi_modal_inputs["pixel_values_videos"] + if "video_grid_thw" in multi_modal_inputs: + forward_args["video_grid_thw"] = multi_modal_inputs["video_grid_thw"] + + check_for_nan_in_loss = state.cfg.rerun_state_machine.check_for_nan_in_loss + check_for_spiky_loss = state.cfg.rerun_state_machine.check_for_spiky_loss + with straggler_timer: + if return_schedule_plan: + assert config.overlap_moe_expert_parallel_comm, ( + "overlap_moe_expert_parallel_comm must be enabled to return the schedule plan" + ) + schedule_plan = model.build_schedule_plan( + tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask + ) + loss_function = _create_loss_function(loss_mask, check_for_nan_in_loss, check_for_spiky_loss) + return schedule_plan, loss_function + else: + output_tensor = model(**forward_args) + + loss_function = _create_loss_function(loss_mask, check_for_nan_in_loss, check_for_spiky_loss) + + return output_tensor, loss_function diff --git a/src/megatron/bridge/training/train.py b/src/megatron/bridge/training/train.py index 5330079637..4968138b99 100644 --- a/src/megatron/bridge/training/train.py +++ b/src/megatron/bridge/training/train.py @@ -46,7 +46,11 @@ from megatron.core.rerun_state_machine import RerunDataIterator, get_rerun_state_machine from megatron.core.transformer import MegatronModule from megatron.core.transformer.cuda_graphs import TECudaGraphHelper -from megatron.core.utils import check_param_hashes_across_dp_replicas, get_model_config +from megatron.core.utils import check_param_hashes_across_dp_replicas, get_model_config, get_attr_wrapped_model +from megatron.core.transformer.vision_cuda_graphs import ( + VisionTECudaGraphHelper, + get_vision_cuda_graph_seq_length, +) from modelopt.torch.distill.plugins.megatron import get_tensor_shapes_adjust_fn_for_distillation from megatron.bridge.training import fault_tolerance @@ -242,8 +246,39 @@ def train( seq_length=config.model.seq_length, micro_batch_size=config.train.micro_batch_size, optimizers=[optimizer], + pg_collection=pg_collection, ) + # Capture Vision Encoder CUDA Graphs (separate from language model). 
+
+    # Check if vision encoder has CUDA graph enabled
+    vision_cuda_graph_helper = None
+    vision_cuda_graph_impl = getattr(config.model, 'vision_cuda_graph_impl', None)
+    if vision_cuda_graph_impl == "transformer_engine":
+        # Try to get vision config from the model
+        try:
+            for model_chunk in model:
+                unwrapped = get_attr_wrapped_model(
+                    model_chunk, 'vision_model', allow_none=True, return_model_obj=True
+                )
+                if unwrapped is not None and hasattr(unwrapped, 'vision_model') and unwrapped.vision_model is not None:
+                    vision_model_config = unwrapped.vision_model.config
+                    if vision_model_config.cuda_graph_impl == "transformer_engine":
+                        vision_seq_length = get_vision_cuda_graph_seq_length(vision_model_config)
+                        vision_cuda_graph_helper = VisionTECudaGraphHelper(
+                            model=model,
+                            vision_config=vision_model_config,
+                            vision_seq_length=vision_seq_length,
+                            micro_batch_size=config.train.micro_batch_size,
+                            num_microbatches=get_num_microbatches(),
+                        )
+                        print_rank_0(
+                            f"Vision encoder CUDA graph enabled with seq_length={vision_seq_length}"
+                        )
+                    break
+        except Exception as e:
+            print_rank_0(f"Warning: Failed to initialize vision CUDA graph helper: {e}")
+            vision_cuda_graph_helper = None
+
     # Track train step elapsed time for throughput logging
     history_wct = None
     if config.logger.log_throughput_to_tensorboard:
@@ -328,6 +363,16 @@
                 enable_forward_pre_hook(model)
             cuda_graph_helper.cuda_graph_set_manual_hooks()
 
+        # Capture Vision Encoder CUDA Graphs after warmup (separate from language model).
+        if (
+            vision_cuda_graph_helper is not None
+            and not vision_cuda_graph_helper.graphs_created()
+            and global_state.train_state.step - start_iteration == model_config.cuda_graph_warmup_steps
+        ):
+            print_rank_0("Capturing vision encoder CUDA graphs...")
+            vision_cuda_graph_helper.create_cudagraphs()
+            vision_cuda_graph_helper.cuda_graph_set_manual_hooks()
+
         # Run training step.
         fault_tolerance.on_training_step_start(global_state)
         (
@@ -396,6 +441,13 @@
         ):
             assert cuda_graph_helper.graphs_created(), "CUDA Graphs should have been created."
             cuda_graph_helper.cuda_graph_set_manual_hooks()
+        # Also set manual hooks for vision encoder CUDA graphs if enabled
+        if (
+            vision_cuda_graph_helper is not None
+            and model_config.cuda_graph_warmup_steps == 0
+            and vision_cuda_graph_helper.graphs_created()
+        ):
+            vision_cuda_graph_helper.cuda_graph_set_manual_hooks()
 
         global_state.train_state.step += 1
 
@@ -537,7 +589,7 @@
         if should_exit:
             break
 
-    _delete_cuda_graphs(cuda_graph_helper)
+    _delete_cuda_graphs(cuda_graph_helper, vision_cuda_graph_helper)
 
     # Flush TensorBoard, WandB writers and one-logger.
     writer = global_state.tensorboard_logger
@@ -1308,7 +1360,10 @@
         optim_instance._copy_main_params_to_param_buffer()
 
 
-def _delete_cuda_graphs(cuda_graph_helper: TECudaGraphHelper):
+def _delete_cuda_graphs(
+    cuda_graph_helper: TECudaGraphHelper,
+    vision_cuda_graph_helper: Optional[VisionTECudaGraphHelper] = None,
+):
     """
     Delete the CUDA graph object as they hold a reference to the some of the nccl buffers,
     thus blocking the process-destory (torch.dist.destroy_process_group()) at the end of the training loop.
 
    TODO: Move this method to MCore.
 
    Args:
-        cuda_graph_helper: The TECudaGraphHelper object.
+        cuda_graph_helper: The TECudaGraphHelper object for the language model.
+        vision_cuda_graph_helper: The VisionTECudaGraphHelper object for the vision encoder.
+    """
@@ -1335,6 +1391,10 @@
                     del cuda_graph
             del layer.cuda_graphs
 
+    # Cleanup vision encoder CUDA graphs
+    if vision_cuda_graph_helper is not None:
+        vision_cuda_graph_helper.delete_cuda_graphs()
+
     # Run GC to collect the freshed object
     gc.collect()
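The vision helper follows the usual warmup-then-capture discipline, which is also why its inputs must be padded to a fixed length. A generic PyTorch sketch of that pattern (plain torch.cuda.CUDAGraph, not the TE helper itself):

import torch

def capture_after_warmup(fn, example_input: torch.Tensor, warmup_steps: int = 3):
    # Warm up on a side stream so allocator state settles before capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(warmup_steps):
            fn(example_input)
    torch.cuda.current_stream().wait_stream(s)

    graph = torch.cuda.CUDAGraph()
    static_in = example_input.clone()
    with torch.cuda.graph(graph):
        static_out = fn(static_in)

    def replay(x: torch.Tensor) -> torch.Tensor:
        static_in.copy_(x)  # fixed shapes only: copy into the captured buffer
        graph.replay()
        return static_out

    return replay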
""" @@ -1335,6 +1391,10 @@ def _delete_cuda_graphs(cuda_graph_helper: TECudaGraphHelper): del cuda_graph del layer.cuda_graphs + # Cleanup vision encoder CUDA graphs + if vision_cuda_graph_helper is not None: + vision_cuda_graph_helper.delete_cuda_graphs() + # Run GC to collect the freshed object gc.collect() diff --git a/src/megatron/bridge/training/utils/packed_seq_utils.py b/src/megatron/bridge/training/utils/packed_seq_utils.py index 98dbd6d5ac..8e265f2403 100644 --- a/src/megatron/bridge/training/utils/packed_seq_utils.py +++ b/src/megatron/bridge/training/utils/packed_seq_utils.py @@ -15,6 +15,7 @@ from __future__ import annotations import torch +import math from megatron.core.packed_seq_params import PackedSeqParams @@ -56,23 +57,176 @@ def get_packed_seq_params(batch: dict[str, torch.Tensor]) -> PackedSeqParams: max_seqlen = batch["max_seqlen"].squeeze() if "max_seqlen" in batch else None - # When cu_seqlens_unpadded is present (pad_seq_to_mult > 1), pass both unpadded and padded - # for proper THD CP support. Otherwise, just use cu_seqlens_padded to avoid slower TE kernel. - if cu_seqlens_unpadded is not None: - return PackedSeqParams( - cu_seqlens_q=cu_seqlens_unpadded, - cu_seqlens_kv=cu_seqlens_unpadded, - cu_seqlens_q_padded=cu_seqlens_padded, - cu_seqlens_kv_padded=cu_seqlens_padded, - max_seqlen_q=max_seqlen, - max_seqlen_kv=max_seqlen, - qkv_format="thd", - ) + # When unpadded lengths are available, use them for q/kv and keep padded + # offsets for kernels that support padded variants. + return PackedSeqParams( + cu_seqlens_q=cu_seqlens_unpadded if cu_seqlens_unpadded is not None else cu_seqlens_padded, + cu_seqlens_kv=cu_seqlens_unpadded if cu_seqlens_unpadded is not None else cu_seqlens_padded, + cu_seqlens_q_padded=cu_seqlens_padded, + cu_seqlens_kv_padded=cu_seqlens_padded, + max_seqlen_q=max_seqlen, + max_seqlen_kv=max_seqlen, + qkv_format="thd", + ) + + +# Copied from verl/verl/models/mcore/util.py +from megatron.core import mpu + + +def preprocess_packed_seqs( + input_ids: torch.Tensor, attention_mask: torch.Tensor, pre_process: bool = True, use_fp8_padding=False +) -> tuple[torch.Tensor, PackedSeqParams]: + """ + Preprocess packed sequences + CP splits sequence into CP*2 chunks, and each GPU gets 2 chunks (GPU0 gets first and last chunks, GPU1 + gets second and second last chunks, and so on), this is for load balancing with causal masking. 
+
+
+# Copied from verl/verl/models/mcore/util.py
+from megatron.core import mpu
+
+
+def preprocess_packed_seqs(
+    input_ids: torch.Tensor, attention_mask: torch.Tensor, pre_process: bool = True, use_fp8_padding=False
+) -> tuple[torch.Tensor, PackedSeqParams]:
+    """
+    Preprocess packed sequences.
+    CP splits the sequence into 2*CP chunks and each GPU gets two of them (GPU0 gets the first and last
+    chunks, GPU1 the second and second-to-last, and so on); this balances the load under causal masking.
+    See https://github.com/NVIDIA/TransformerEngine/issues/1368
+    """
+    batch_size = input_ids.shape[0]
+
+    seqlens_in_batch = (
+        attention_mask.sum(dim=-1, dtype=torch.int32)
+        if attention_mask is not None
+        else torch.ones((batch_size), dtype=torch.int32, device=input_ids.device) * input_ids.shape[1]
+    )
+    tp_size = mpu.get_tensor_model_parallel_world_size()
+    cp_size = mpu.get_context_parallel_world_size()
+    cp_rank = mpu.get_context_parallel_rank()
+    align_size = tp_size * cp_size * 2 if cp_size > 1 else tp_size
+    if use_fp8_padding:
+        # if fp8 is enabled, ensure the sequence is padded to multiples of 16 for better performance
+        original_align_size = align_size
+        align_size = math.lcm(16, align_size)
+
+    pad_size = (align_size - seqlens_in_batch % align_size) % align_size
+    seqlens_in_batch_padded = seqlens_in_batch + pad_size
+
+    cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device)
+    cu_seqlens[1:] = torch.cumsum(seqlens_in_batch, dim=0)
+    cu_seqlens_padded = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device)
+    cu_seqlens_padded[1:] = torch.cumsum(seqlens_in_batch_padded, dim=0)
+
+    if use_fp8_padding:
+        # pad the total packed length (via the last sequence) to a multiple of 128 for TE compatibility
+        align_size_last = original_align_size * 128
+        pad_size_last = (align_size_last - cu_seqlens_padded[-1] % align_size_last) % align_size_last
+        cu_seqlens_padded[-1] += pad_size_last
+        seqlens_in_batch_padded[-1] += pad_size_last
+
+    # ----------------------------------------------------------------------------
+    # Move the index information needed in the subsequent loop to the CPU at once,
+    # to avoid frequent .item() calls in the loop that cause D2H synchronization
+    # ----------------------------------------------------------------------------
+    seqlens_in_batch_cpu: list[int] = seqlens_in_batch.tolist()  # original valid lengths
+    seqlens_in_batch_padded_cpu: list[int] = seqlens_in_batch_padded.tolist()  # lengths after padding
+    cu_seqlens_padded_cpu: list[int] = cu_seqlens_padded.tolist()  # start positions (after padding)
+
+    # Pure Python int calculation to avoid further synchronization
+    max_seqlen_in_batch = max(seqlens_in_batch_padded_cpu)
+
+    shape = list(input_ids.shape[1:])
+    shape[0] = sum(seqlens_in_batch_padded_cpu) // cp_size
+    if pre_process:
+        input_ids_rmpad = torch.zeros(shape, dtype=input_ids.dtype, device=input_ids.device)
+        for i in range(batch_size):
+            # Use Python ints, so there is no GPU -> CPU sync in the loop
+            if cp_size <= 1:
+                seqlen = seqlens_in_batch_cpu[i]
+                start_idx = cu_seqlens_padded_cpu[i]
+                input_ids_rmpad[start_idx : start_idx + seqlen] = input_ids[i, attention_mask[i]]
+                continue
+
+            seqlen_padded_i = seqlens_in_batch_padded_cpu[i]
+            seqlen = seqlen_padded_i // cp_size
+            half_seqlen = seqlen // 2
+            start_idx = cu_seqlens_padded_cpu[i] // cp_size
+            # split to 2 chunks
+            d = input_ids[i, attention_mask[i]]
+            input_ids_rmpad[start_idx : start_idx + half_seqlen] = d[
+                half_seqlen * cp_rank : half_seqlen * (cp_rank + 1)
+            ]
+
+            remain_start = seqlen_padded_i - half_seqlen * (cp_rank + 1)
+            remain_end = seqlen_padded_i - half_seqlen * cp_rank
+            remain_end = min(remain_end, d.shape[0])
+            remain_len = remain_end - remain_start
+            if remain_len > 0:
+                input_ids_rmpad[start_idx + half_seqlen : start_idx + half_seqlen + remain_len] = d[
+                    remain_start:remain_end
+                ]
+
+    packed_seq_params = PackedSeqParams(
+        qkv_format="thd",
+        cu_seqlens_q=cu_seqlens_padded,
+        max_seqlen_q=max_seqlen_in_batch,
+        cu_seqlens_kv=cu_seqlens_padded,
+        max_seqlen_kv=max_seqlen_in_batch,
+        cu_seqlens_q_padded=cu_seqlens_padded,
+        cu_seqlens_kv_padded=cu_seqlens_padded,
+    )
+    if pre_process:
+        return input_ids_rmpad.unsqueeze(0), packed_seq_params
     else:
-        return PackedSeqParams(
-            cu_seqlens_q=cu_seqlens_padded,
-            cu_seqlens_kv=cu_seqlens_padded,
-            max_seqlen_q=max_seqlen,
-            max_seqlen_kv=max_seqlen,
-            qkv_format="thd",
-        )
+        return input_ids, packed_seq_params
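The 2*CP chunk assignment described in the docstring can be sketched in isolation (a toy gather, not the rmpad code above):

import torch

def cp_chunks(seq: torch.Tensor, cp_size: int, cp_rank: int) -> torch.Tensor:
    # Split into 2*cp_size chunks; rank r takes chunk r and chunk (2*cp_size - 1 - r),
    # so each rank sees one early and one late chunk under causal masking.
    chunks = seq.chunk(2 * cp_size)
    return torch.cat([chunks[cp_rank], chunks[2 * cp_size - 1 - cp_rank]])

seq = torch.arange(16)
assert cp_chunks(seq, 2, 0).tolist() == [0, 1, 2, 3, 12, 13, 14, 15]
assert cp_chunks(seq, 2, 1).tolist() == [4, 5, 6, 7, 8, 9, 10, 11]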
+
+
+def postprocess_packed_seqs(
+    output: torch.Tensor,
+    packed_seq_params: PackedSeqParams,
+    attention_mask: torch.Tensor,
+    batch_size: int,
+    seq_len: int,
+    post_process: bool = True,
+) -> torch.Tensor:
+    """
+    Postprocess packed sequences.
+    This function is currently unused.
+    """
+    if not post_process:
+        return output
+
+    # -------------------------------------------------------------------------
+    # Move the lengths and offsets needed for subsequent Python-level indexing to the CPU in advance,
+    # to avoid a large number of .item() calls in the loop
+    # -------------------------------------------------------------------------
+    cu_padded_cpu: list[int] = packed_seq_params.cu_seqlens_q_padded.tolist()
+    seq_lens_cpu: list[int] = attention_mask.sum(dim=1, dtype=torch.int32).cpu().tolist()
+
+    shape = [batch_size, seq_len] + list(output.shape[2:])  # [1, packed, dim] -> [batch_size, seq_len, dim]
+    output_new = torch.zeros(shape, dtype=output.dtype, device=output.device)
+
+    cp_size = mpu.get_context_parallel_world_size()
+    # all-gather the output across the context parallel group
+    if cp_size > 1:
+        # output shape: [1, packed_len, hidden_dim]
+        # need to gather across cp group and concatenate in sequence dimension
+        output_list = [torch.empty_like(output, dtype=output.dtype) for _ in range(cp_size)]
+        torch.distributed.all_gather(output_list, output.detach(), group=mpu.get_context_parallel_group())
+        output_list[mpu.get_context_parallel_rank()] = output
+    else:
+        output_list = [output]
+    for i in range(batch_size):
+        if cp_size <= 1:
+            s = seq_lens_cpu[i]
+            start_idx = cu_padded_cpu[i]
+            output_new[i, attention_mask[i]] = output[0][start_idx : start_idx + s]
+            continue
+        s_len_padded_chunk = (cu_padded_cpu[i + 1] - cu_padded_cpu[i]) // cp_size
+        half_seqlen = s_len_padded_chunk // 2
+        s_len = seq_lens_cpu[i]
+        s_len_padded = s_len_padded_chunk * cp_size
+        tmp = torch.empty(s_len_padded, *output.shape[2:], device=output.device, dtype=output.dtype)
+        for j in range(cp_size):
+            o = output_list[j][0]
+            # split to 2 chunks
+            packed_start_idx = cu_padded_cpu[i] // cp_size
+            o0, o1 = (
+                o[packed_start_idx : packed_start_idx + half_seqlen],
+                o[packed_start_idx + half_seqlen : packed_start_idx + s_len_padded_chunk],
+            )
+            tmp[j * half_seqlen : (j + 1) * half_seqlen] = o0
+            tmp[s_len_padded - (j + 1) * half_seqlen : s_len_padded - j * half_seqlen] = o1
+        output_new[i, attention_mask[i]] = tmp[:s_len]
+
+    return output_new
diff --git a/src/megatron/bridge/training/utils/train_utils.py b/src/megatron/bridge/training/utils/train_utils.py
index acd9bc5d3c..040c929e74 100644
--- a/src/megatron/bridge/training/utils/train_utils.py
+++ b/src/megatron/bridge/training/utils/train_utils.py
@@ -580,6 +580,7 @@ def training_log(
                 num_layers=layers,
                 moe_layer_freq=config.model.moe_layer_freq,
                 mtp_num_layers=config.model.mtp_num_layers,
+                pg_collection=pg_collection,
             )
             if config.model.mtp_num_layers is not None:
                 mtp_loss_scale = 1 / get_num_microbatches()
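The vlm_step change below makes every microbatch use the configured seq_length even without pipeline parallelism. The shape fixing itself reduces to a pad-or-truncate along the sequence dimension, sketched here with a hypothetical helper:

import torch
import torch.nn.functional as F

def fix_length(t: torch.Tensor, seq_len: int, pad_value: int = 0) -> torch.Tensor:
    # Pad or truncate the last dimension so all stages see identical shapes.
    cur = t.shape[-1]
    if cur < seq_len:
        return F.pad(t, (0, seq_len - cur), value=pad_value)
    return t[..., :seq_len]

tokens = torch.randint(0, 100, (2, 37))
assert fix_length(tokens, 64).shape == (2, 64)
assert fix_length(tokens, 16).shape == (2, 16)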
diff --git a/src/megatron/bridge/training/vlm_step.py b/src/megatron/bridge/training/vlm_step.py
index 1e013eeebd..34bd3e1a23 100644
--- a/src/megatron/bridge/training/vlm_step.py
+++ b/src/megatron/bridge/training/vlm_step.py
@@ -121,8 +121,8 @@ def get_batch(
         tuple of tensors containing tokens, labels, loss_mask, attention_mask, position_ids,
         cu_seqlens, cu_seqlens_argmin, max_seqlen, visual_inputs (container of optional modalities)
     """
-    is_first = is_pp_first_stage(pg_collection.pp)
-    is_last = is_pp_last_stage(pg_collection.pp)
+    is_first = True
+    is_last = True
 
     # All PP stages load from iterator to get input_ids and visual grid info
     # This allows each stage to compute MRoPE position_ids locally without broadcasting
@@ -142,7 +142,7 @@ def get_batch(
             batch[k] = v
 
     # When using pipeline parallelism, ensure fixed shapes equal to cfg.model.seq_length
-    if getattr(cfg.model, "pipeline_model_parallel_size", 1) > 1:
+    if True:  # always enforce fixed shapes, with or without pipeline parallelism
         seq_len = cfg.model.seq_length
 
         tokens_or_input = batch.get("tokens") if batch.get("tokens") is not None else batch.get("input_ids")
diff --git a/tests/unit_tests/models/qwen_vl/modelling_qwen3_vl/test_model.py b/tests/unit_tests/models/qwen_vl/modelling_qwen3_vl/test_model.py
index 68db2aba13..37e310dc08 100644
--- a/tests/unit_tests/models/qwen_vl/modelling_qwen3_vl/test_model.py
+++ b/tests/unit_tests/models/qwen_vl/modelling_qwen3_vl/test_model.py
@@ -29,6 +29,7 @@
 import torch.nn.functional as F
 from megatron.core import parallel_state
 from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
+from megatron.core.process_groups_config import ProcessGroupCollection
 from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
 from PIL import Image
 from transformers import AutoProcessor, Qwen3VLMoeConfig
@@ -253,6 +254,7 @@
     def test_model_freeze_api(self, freeze_all, hf_config):
         """Test model freeze API."""
         self._setup_parallel_state(tp_size=1, ep_size=1, pp_size=1)
+        pg_collection = ProcessGroupCollection.use_mpu_process_groups()
 
         vision_transformer_config = self.get_vision_transformer_config(hf_config)
         language_transformer_config = self.get_language_transformer_config(hf_config)
@@ -267,6 +269,7 @@
             post_process=True,
             add_encoder=True,
             add_decoder=True,
+            pg_collection=pg_collection,
         )
 
         if torch.cuda.is_available():
@@ -284,7 +287,8 @@
     @pytest.mark.timeout(50)
     def test_shared_embedding_or_output_weight(self, hf_config):
         """Test shared_embedding_or_output_weight method."""
-        self._setup_parallel_state(tp_size=1, ep_size=1, pp_size=1)
+        self._setup_parallel_state(tp_size=1, ep_size=1, pp_size=1)
+        pg_collection = ProcessGroupCollection.use_mpu_process_groups()  # create pg_collection from the initialized mpu
 
         vision_transformer_config = self.get_vision_transformer_config(hf_config)
         language_transformer_config = self.get_language_transformer_config(hf_config)
@@ -300,6 +304,7 @@
             post_process=True,
             add_encoder=True,
             add_decoder=True,
+            pg_collection=pg_collection,
         )
 
         weight = model.shared_embedding_or_output_weight()
@@ -324,6 +329,7 @@
     def test_set_input_tensor(self, hf_config):
         """Test set_input_tensor method."""
         self._setup_parallel_state(tp_size=1, ep_size=1, pp_size=1)
+        pg_collection = ProcessGroupCollection.use_mpu_process_groups()
 
         vision_transformer_config = self.get_vision_transformer_config(hf_config)
         language_transformer_config = 
self.get_language_transformer_config(hf_config) @@ -339,6 +345,7 @@ def test_set_input_tensor(self, hf_config): post_process=True, add_encoder=True, add_decoder=True, + pg_collection=pg_collection, ) if torch.cuda.is_available():