diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index 40a30d2bd3..f1c83adcbd 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -363,6 +363,7 @@ jobs: - script: L2_Launch_models_qwen - script: L2_Launch_models_qwen_quantization - script: L2_Launch_models_qwen_vl + - script: L2_Launch_models_qwen35_vl - script: L2_Launch_recipes_gemma_vl - script: L2_Launch_recipes_gpt_oss - script: L2_Launch_models_qwen_vl_quantization diff --git a/examples/conversion/compare_hf_and_megatron/compare.py b/examples/conversion/compare_hf_and_megatron/compare.py index 60ce377cd1..882ea8ff94 100644 --- a/examples/conversion/compare_hf_and_megatron/compare.py +++ b/examples/conversion/compare_hf_and_megatron/compare.py @@ -91,6 +91,7 @@ """ import argparse +import gc import importlib import os import sys @@ -318,7 +319,13 @@ def vlm_forward_step(data_iterator, model, **kwargs) -> torch.Tensor: def loss_func(x, **kwargs): return x - return model(**forward_args), loss_func + model_output = model(**forward_args) + if isinstance(model_output, tuple): + output_tensor, _ = model_output + else: + output_tensor = model_output + + return output_tensor, loss_func def load_image(image_path: str) -> Image.Image: @@ -609,6 +616,11 @@ def _load_megatron_model(args): model_provider.finalize() megatron_model = model_provider.provide_distributed_model(wrap_with_ddp=False) + # Workaround: disable MTP for inference (causes hangs on NCCL collectives) + for m in megatron_model: + m.config.mtp_num_layers = None + m.config.grad_scale_func = None + model_components = [m.eval() for m in megatron_model] # Register debug hooks if enabled @@ -715,11 +727,10 @@ def compare_models_one_step(args) -> None: ) del hf_model - # Reload Megatron model to ensure a fresh instance before comparison - megatron_model, _ = _load_megatron_model(args) + gc.collect() + torch.cuda.empty_cache() - # Broadcast HF results to all ranks after Megatron initialization - # 
(following the pattern from generate_from_hf.py) + # Broadcast HF results to all ranks if torch.distributed.is_initialized(): # Create tensors for broadcasting if they don't exist on non-rank-0 if hf_next_token is None: @@ -731,6 +742,9 @@ def compare_models_one_step(args) -> None: ) hf_logits = torch.zeros(vocab_size, device=input_ids.device, dtype=torch.float32) + # Ensure consistent dtype across ranks before broadcast + hf_logits = hf_logits.float() + # Broadcast from rank 0 to all ranks torch.distributed.broadcast(hf_next_token, 0) torch.distributed.broadcast(hf_logits, 0) @@ -778,7 +792,10 @@ def compare_models_one_step(args) -> None: megatron_logits = megatron_output[0, -1, :] megatron_next_token = torch.argmax(megatron_logits, dim=-1) - if not torch.distributed.is_initialized() or parallel_state.get_tensor_model_parallel_rank() == 0: + if not torch.distributed.is_initialized() or ( + parallel_state.get_tensor_model_parallel_rank() == 0 + and parallel_state.get_expert_model_parallel_rank() == 0 + ): print(f"Megatron output shape: {megatron_output.shape}") print(f"Megatron logits stats - mean: {megatron_logits.mean():.4f}, std: {megatron_logits.std():.4f}") print( diff --git a/examples/conversion/hf_megatron_roundtrip_multi_gpu.py b/examples/conversion/hf_megatron_roundtrip_multi_gpu.py index eebb8af8e2..df87867029 100644 --- a/examples/conversion/hf_megatron_roundtrip_multi_gpu.py +++ b/examples/conversion/hf_megatron_roundtrip_multi_gpu.py @@ -62,6 +62,8 @@ # These are compared in float32 to avoid false mismatches. IGNORE_PRECISION_PARAMS = [ "e_score_correction_bias", + "A_log", + "linear_attn.norm.weight", ] diff --git a/examples/models/vlm/qwen35_vl/conversion.sh b/examples/models/vlm/qwen35_vl/conversion.sh new file mode 100755 index 0000000000..b7bcd54ad3 --- /dev/null +++ b/examples/models/vlm/qwen35_vl/conversion.sh @@ -0,0 +1,52 @@ +#!/usr/bin/env bash +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Workspace directory for checkpoints and results +WORKSPACE=${WORKSPACE:-/workspace} +MODEL_NAME=Qwen3.5-35B-A3B # Qwen3.5-35B-A3B, Qwen3.5-122B-A10B, Qwen3.5-397B-A17B, Qwen3.5-27B + +if [ "${MODEL_NAME}" = "Qwen3.5-27B" ]; then + HF_MODEL_CLASS="Qwen3_5ForConditionalGeneration" +else + HF_MODEL_CLASS="Qwen3_5MoeForConditionalGeneration" +fi + +# Make sure to upgrade to transformers >= 5.2.0 +# uv add transformers>=5.2.0 + +# Import HF → Megatron +uv run python examples/conversion/convert_checkpoints.py import \ + --hf-model Qwen/${MODEL_NAME} \ + --megatron-path ${WORKSPACE}/${MODEL_NAME} \ + --torch-dtype bfloat16 + +# HF and Megatron models logits comparison validation +uv run python -m torch.distributed.run --nproc_per_node=8 examples/conversion/compare_hf_and_megatron/compare.py \ + --hf_model_path Qwen/${MODEL_NAME} \ + --megatron_model_path ${WORKSPACE}/${MODEL_NAME} \ + --model_class "${HF_MODEL_CLASS}" \ + --image_path "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg" \ + --prompt "Describe this image." 
\ + --tp 1 --pp 1 --ep 8 + +# Export Megatron → HF +uv run python examples/conversion/convert_checkpoints.py export \ + --hf-model Qwen/${MODEL_NAME} \ + --megatron-path ${WORKSPACE}/${MODEL_NAME}/iter_0000000 \ + --hf-path ${WORKSPACE}/${MODEL_NAME}-hf-export + +# Round-trip validation +uv run python -m torch.distributed.run --nproc_per_node=8 examples/conversion/hf_megatron_roundtrip_multi_gpu.py \ + --hf-model-id Qwen/${MODEL_NAME} --tp 1 --pp 2 --ep 4 --trust-remote-code diff --git a/examples/models/vlm/qwen35_vl/inference.sh b/examples/models/vlm/qwen35_vl/inference.sh new file mode 100755 index 0000000000..17cfbec635 --- /dev/null +++ b/examples/models/vlm/qwen35_vl/inference.sh @@ -0,0 +1,43 @@ +#!/usr/bin/env bash +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Workspace directory for checkpoints and results +WORKSPACE=${WORKSPACE:-/workspace} +MODEL_NAME=Qwen3.5-35B-A3B # Qwen3.5-35B-A3B, Qwen3.5-122B-A10B, Qwen3.5-27B + +# Inference with Hugging Face checkpoints +uv run python -m torch.distributed.run --nproc_per_node=8 examples/conversion/hf_to_megatron_generate_vlm.py \ + --hf_model_path Qwen/${MODEL_NAME} \ + --image_path "https://huggingface.co/nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16/resolve/main/images/table.png" \ + --prompt "Describe this image." 
\ + --max_new_tokens 50 \ + --tp 2 --pp 2 --ep 4 + +# Inference with imported Megatron checkpoints +uv run python -m torch.distributed.run --nproc_per_node=8 examples/conversion/hf_to_megatron_generate_vlm.py \ + --hf_model_path Qwen/${MODEL_NAME} \ + --megatron_model_path ${WORKSPACE}/${MODEL_NAME}/iter_0000000 \ + --image_path "https://huggingface.co/nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16/resolve/main/images/table.png" \ + --prompt "Describe this image." \ + --max_new_tokens 50 \ + --tp 2 --pp 2 --ep 4 + +# Inference with exported HF checkpoints +uv run python -m torch.distributed.run --nproc_per_node=8 examples/conversion/hf_to_megatron_generate_vlm.py \ + --hf_model_path ${WORKSPACE}/${MODEL_NAME}-hf-export \ + --image_path "https://huggingface.co/nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16/resolve/main/images/table.png" \ + --prompt "Describe this image." \ + --max_new_tokens 50 \ + --tp 2 --pp 2 --ep 4 diff --git a/examples/models/vlm/qwen35_vl/slurm_inference.sh b/examples/models/vlm/qwen35_vl/slurm_inference.sh new file mode 100755 index 0000000000..ce52f59e1c --- /dev/null +++ b/examples/models/vlm/qwen35_vl/slurm_inference.sh @@ -0,0 +1,180 @@ +#!/bin/bash +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# ============================================================================== +# Qwen3.5-VL Multi-Node Distributed Inference for Qwen3.5-397B-A17B +# Recommended: TP=2, PP=4, EP=8 for full model (32 GPUs, 4 nodes) +# +# Usage: +# 1. Modify the #SBATCH directives below for your cluster +# 2. Set MODEL_PATH and CHECKPOINT_PATH as needed +# 3. Set CONTAINER_IMAGE or use --no-container-image for bare metal +# 4. Submit: sbatch slurm_inference.sh +# ============================================================================== + +#SBATCH --job-name=qwen35v-inference +#SBATCH --nodes=4 # Number of nodes (32 GPUs = 4 nodes × 8 GPUs) +#SBATCH --ntasks-per-node=8 # Tasks per node (1 per GPU) +#SBATCH --gpus-per-node=8 # GPUs per node +#SBATCH --time=02:00:00 # Max run time (2 hours) +#SBATCH --partition=gpu # Partition name +#SBATCH --account=my_account # Account name +#SBATCH --output=logs/qwen35v_inference_%j.out +#SBATCH --error=logs/qwen35v_inference_%j.err +#SBATCH --exclusive # Exclusive node access + +# ============================================================================== +# CONFIGURATION +# ============================================================================== + +# Workspace directory +WORKSPACE=${WORKSPACE:-/workspace} + +# Model configuration +MODEL_NAME=Qwen3.5-397B-A17B + +# Option 1: Use HuggingFace model path (will load and convert on-the-fly) +MODEL_PATH=${WORKSPACE}/${MODEL_NAME} +# MODEL_PATH=Qwen/${MODEL_NAME} # Or use HF Hub path + +# Option 2: Use pre-converted Megatron checkpoint (faster) +MEGATRON_CHECKPOINT=${WORKSPACE}/${MODEL_NAME}/iter_0000000 +# Comment out to use HF model directly + +# Inference configuration +IMAGE_PATH="https://huggingface.co/nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16/resolve/main/images/table.png" +PROMPT="Describe this image." 
+MAX_NEW_TOKENS=1000 + +# Parallelism configuration for 32 GPUs (4 nodes × 8 GPUs) +TP=2 # Tensor Parallelism +PP=4 # Pipeline Parallelism +EP=8 # Expert Parallelism (MoE) + +# Container configuration (required for SLURM pyxis) +CONTAINER_IMAGE="" +# CONTAINER_IMAGE="/path/to/nemo-framework.sqsh" + +# Container mounts (optional, space-separated) +CONTAINER_MOUNTS="" +# CONTAINER_MOUNTS="/data:/data /workspace:/workspace" + +# Set to true to run without container (bare metal) +NO_CONTAINER=false + +# ============================================================================== +# Environment Setup +# ============================================================================== + +# NCCL optimizations +export TORCH_NCCL_AVOID_RECORD_STREAMS=1 +export NCCL_NVLS_ENABLE=0 + +# UV cache on shared filesystem (recommended for multi-node setups) +# Pre-sync once before submitting jobs: UV_CACHE_DIR=/path/to/cache uv sync +# export UV_CACHE_DIR="/path/to/shared/uv_cache" + +# HuggingFace cache directory (recommended for shared filesystem) +# export HF_HOME="/path/to/shared/HF_HOME" + +# Authentication tokens +# export HF_TOKEN="hf_your_token_here" + +# Make sure to upgrade container image to transformers >= 5.2.0 (required for Qwen3.5) +# Run once: uv add "transformers>=5.2.0" + +# ============================================================================== +# Job Execution +# ============================================================================== + +echo "======================================" +echo "Qwen3.5-VL Multi-Node Inference" +echo "======================================" +echo "Job ID: $SLURM_JOB_ID" +echo "Nodes: $SLURM_JOB_NUM_NODES" +echo "GPUs per node: $SLURM_GPUS_PER_NODE" +echo "Total GPUs: $((SLURM_JOB_NUM_NODES * SLURM_GPUS_PER_NODE))" +echo "Model: $MODEL_NAME" +echo "Parallelism: TP=$TP, PP=$PP, EP=$EP" +echo "======================================" + +# Create logs directory +mkdir -p logs + +# Calculate total processes 
+TOTAL_GPUS=$((SLURM_JOB_NUM_NODES * SLURM_GPUS_PER_NODE)) +REQUIRED_GPUS=$(( (TP > EP ? TP : EP) * PP )) + +# Validate parallelism configuration +if [ $REQUIRED_GPUS -ne $TOTAL_GPUS ]; then + echo "ERROR: Parallelism mismatch!" + echo " max(TP, EP) × PP = max($TP, $EP) × $PP = $REQUIRED_GPUS" + echo " Total allocated GPUs = $TOTAL_GPUS" + echo " These must be equal!" + exit 1 +fi + +MEGATRON_CKPT_ARG="" +if [ -n "$MEGATRON_CHECKPOINT" ]; then + MEGATRON_CKPT_ARG="--megatron_model_path $MEGATRON_CHECKPOINT" +fi + +CMD="uv run --no-sync python examples/conversion/hf_to_megatron_generate_vlm.py \ + --hf_model_path $MODEL_PATH \ + $MEGATRON_CKPT_ARG \ + --image_path \"$IMAGE_PATH\" \ + --prompt \"$PROMPT\" \ + --max_new_tokens $MAX_NEW_TOKENS \ + --tp $TP \ + --pp $PP \ + --ep $EP" + +# Only rank 0 on each node runs uv sync +SYNC_CMD="if [ \"\$SLURM_LOCALID\" -eq 0 ]; then uv sync; else sleep 5; fi" +FULL_CMD="$SYNC_CMD && $CMD" + +echo "Executing inference..." +echo "Command: $CMD" +echo "======================================" + +# Execute based on container configuration +if [ "$NO_CONTAINER" = true ]; then + echo "Running without container (bare metal)" + srun --mpi=pmix bash -c "$FULL_CMD" +else + # Require container image + if [ -z "$CONTAINER_IMAGE" ]; then + echo "ERROR: CONTAINER_IMAGE must be set, or use NO_CONTAINER=true for bare metal." 
+ exit 1 + fi + + echo "Running with container: $CONTAINER_IMAGE" + + # Build srun command with container + SRUN_CMD="srun --mpi=pmix --container-image=$CONTAINER_IMAGE" + + # Add container mounts + if [ -n "$CONTAINER_MOUNTS" ]; then + for mount in $CONTAINER_MOUNTS; do + SRUN_CMD="$SRUN_CMD --container-mounts=$mount" + done + fi + + $SRUN_CMD bash -c "$FULL_CMD" +fi + +echo "======================================" +echo "Inference completed" +echo "======================================" diff --git a/src/megatron/bridge/models/__init__.py b/src/megatron/bridge/models/__init__.py index 4fd486022c..185ec26504 100644 --- a/src/megatron/bridge/models/__init__.py +++ b/src/megatron/bridge/models/__init__.py @@ -182,6 +182,10 @@ Qwen25VLBridge, Qwen25VLModel, Qwen25VLModelProvider, + Qwen35VLBridge, + Qwen35VLModelProvider, + Qwen35VLMoEBridge, + Qwen35VLMoEModelProvider, ) from megatron.bridge.models.qwen_vl.modelling_qwen3_vl import ( Qwen3VLBridge, @@ -331,6 +335,10 @@ "Qwen3VLMoEModelProvider", "Qwen3VLBridge", "Qwen3VLMoEBridge", + "Qwen35VLBridge", + "Qwen35VLModelProvider", + "Qwen35VLMoEBridge", + "Qwen35VLMoEModelProvider", "Gemma3VLBridge", "Gemma3VLModel", "Gemma3VLModelProvider", diff --git a/src/megatron/bridge/models/conversion/param_mapping.py b/src/megatron/bridge/models/conversion/param_mapping.py index 5ac282df88..d0c0d31853 100644 --- a/src/megatron/bridge/models/conversion/param_mapping.py +++ b/src/megatron/bridge/models/conversion/param_mapping.py @@ -1813,6 +1813,109 @@ def resolve(self, captures: Tuple[str, ...]) -> "MegatronParamMapping": ) +class GDNLinearMappingSeparate(MegatronParamMapping[Dict[str, torch.Tensor]]): + """GDN input projection mapping for models with separate QKV, Z, B, A HF weights. 
+ + Unlike :class:`GDNLinearMapping` which expects two fused tensors (``in_proj_qkvz`` + and ``in_proj_ba`` in Qwen3-Next's head-grouped layout), this mapping handles + models that store each projection component separately: + + * ``in_proj_qkv`` - fused Q, K, V projection (flat ``[Q; K; V]``) + * ``in_proj_z`` - Z (gate) projection + * ``in_proj_b`` - B (beta) projection + * ``in_proj_a`` - A (alpha) projection + + Used by **Qwen3.5** whose GDN layers expose four distinct weight matrices. + + The class converts between the 4-tensor HF layout and Megatron's single + ``in_proj`` tensor by first assembling the head-grouped ``qkvz`` / ``ba`` + intermediates expected by the existing :func:`merge_gdn_linear_weights` and + :func:`split_gdn_linear_weights` helpers, keeping the TP-sharding logic + unchanged. + """ + + def __init__(self, megatron_param: str, qkv: str, z: str, b: str, a: str): + """Initialise GDN separate-component mapping. + + Args: + megatron_param: Megatron ``in_proj`` parameter name pattern. + qkv: HF weight pattern for the fused Q/K/V projection. + z: HF weight pattern for the Z (gate) projection. + b: HF weight pattern for the B (beta) projection. + a: HF weight pattern for the A (alpha) projection. 
+ """ + super().__init__(megatron_param, {"qkv": qkv, "z": z, "b": b, "a": a}) + self._tp_mapping = AutoMapping(megatron_param, megatron_param) + + # --------------------------------------------------------------------- # + # HF → Megatron + # --------------------------------------------------------------------- # + def hf_to_megatron( + self, + hf_weights: Dict[str, torch.Tensor], + megatron_module: nn.Module, + ) -> torch.Tensor: + """Merge four separate HF tensors into Megatron's single ``in_proj``.""" + if self.tp_rank == 0: + config = self._get_config(megatron_module) + qkvz, ba = _fuse_gdn_separate_to_grouped( + config, hf_weights["qkv"], hf_weights["z"], hf_weights["b"], hf_weights["a"] + ) + merged = merge_gdn_linear_weights(config, qkvz, ba, tp_size=self.tp_size) + else: + merged = None + + return self._tp_mapping.hf_to_megatron(merged, megatron_module) + + # --------------------------------------------------------------------- # + # Megatron → HF + # --------------------------------------------------------------------- # + def megatron_to_hf( + self, + megatron_weights: Optional[torch.Tensor], + megatron_module: Optional[nn.Module], + ) -> Dict[str, torch.Tensor]: + """Gather shards and split into the four separate HF tensors.""" + if megatron_weights is not None: + megatron_weights = self.maybe_dequantize(megatron_weights) + + # Broadcast config across PP ranks (mirrors GDNLinearMapping). 
+ if megatron_module is None: + config = self.broadcast_obj_from_pp_rank(None) + else: + config = self._get_config(megatron_module) + config = remove_non_pickleables(config, max_depth=3) + config = self.broadcast_obj_from_pp_rank(config) + + packed_dict = self._tp_mapping.megatron_to_hf(megatron_weights, megatron_module) + if not packed_dict: + return {} + + packed_in_proj = next(iter(packed_dict.values())) + qkvz, ba = split_gdn_linear_weights(config, packed_in_proj, tp_size=self.tp_size) + qkv, z, b, a = _split_gdn_grouped_to_separate(config, qkvz, ba) + + return { + self.hf_param["qkv"]: qkv, + self.hf_param["z"]: z, + self.hf_param["b"]: b, + self.hf_param["a"]: a, + } + + # --------------------------------------------------------------------- # + # Pattern resolution + # --------------------------------------------------------------------- # + def resolve(self, captures: Tuple[str, ...]) -> "MegatronParamMapping": + resolved_megatron_param, resolved_hf_param = self._resolve_names(captures) + return type(self)( + resolved_megatron_param, + resolved_hf_param["qkv"], + resolved_hf_param["z"], + resolved_hf_param["b"], + resolved_hf_param["a"], + ) + + class ConcatenatedQKVMapping(MegatronParamMapping[Dict[str, torch.Tensor]]): """ Mapping for interleaved Query/Key/Value attention projection weights. @@ -2431,6 +2534,115 @@ def split_gdn_linear_weights(provider: TransformerConfig, in_proj: torch.Tensor, return qkvz, ba +def _fuse_gdn_separate_to_grouped( + config: TransformerConfig, + qkv: torch.Tensor, + z: torch.Tensor, + b: torch.Tensor, + a: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Convert four separate (flat) GDN projection tensors into the head-grouped + ``qkvz`` and ``ba`` format expected by :func:`merge_gdn_linear_weights`. + + Args: + config: Transformer configuration with GDN head dimensions. + qkv: Flat ``[Q; K; V]`` tensor of shape ``(qk_dim*2 + v_dim, hidden)``. + z: Z projection of shape ``(v_dim, hidden)``. 
+ b: B projection of shape ``(num_v_heads, hidden)``. + a: A projection of shape ``(num_v_heads, hidden)``. + + Returns: + Tuple of (qkvz, ba) in head-grouped layout that + :func:`merge_gdn_linear_weights` can consume directly. + """ + hidden_size = config.hidden_size + qk_head_dim = config.linear_key_head_dim + v_head_dim = config.linear_value_head_dim + num_qk_heads = config.linear_num_key_heads + num_v_heads = config.linear_num_value_heads + qk_dim = qk_head_dim * num_qk_heads + v_dim = v_head_dim * num_v_heads + v_per_group = num_v_heads // num_qk_heads + + expected_qkv = (qk_dim * 2 + v_dim, hidden_size) + expected_z = (v_dim, hidden_size) + expected_ba = (num_v_heads, hidden_size) + if tuple(qkv.shape) != expected_qkv: + raise ValueError(f"qkv shape mismatch: expected {expected_qkv}, got {tuple(qkv.shape)}") + if tuple(z.shape) != expected_z: + raise ValueError(f"z shape mismatch: expected {expected_z}, got {tuple(z.shape)}") + if tuple(b.shape) != expected_ba: + raise ValueError(f"b shape mismatch: expected {expected_ba}, got {tuple(b.shape)}") + if tuple(a.shape) != expected_ba: + raise ValueError(f"a shape mismatch: expected {expected_ba}, got {tuple(a.shape)}") + + # --- Split flat QKV into individual components --- + q_flat, k_flat, v_flat = torch.split(qkv, [qk_dim, qk_dim, v_dim], dim=0) + + # --- Reshape every component to (num_qk_heads, per_group_dim, hidden) --- + q_g = q_flat.reshape(num_qk_heads, qk_head_dim, hidden_size) + k_g = k_flat.reshape(num_qk_heads, qk_head_dim, hidden_size) + v_g = v_flat.reshape(num_qk_heads, v_per_group * v_head_dim, hidden_size) + z_g = z.reshape(num_qk_heads, v_per_group * v_head_dim, hidden_size) + b_g = b.reshape(num_qk_heads, v_per_group, hidden_size) + a_g = a.reshape(num_qk_heads, v_per_group, hidden_size) + + # --- Assemble grouped qkvz and ba --- + qkvz = torch.cat([q_g, k_g, v_g, z_g], dim=1).reshape(-1, hidden_size) + ba = torch.cat([b_g, a_g], dim=1).reshape(-1, hidden_size) + + return qkvz, ba + + +def 
_split_gdn_grouped_to_separate( + config: TransformerConfig, + qkvz: torch.Tensor, + ba: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Convert head-grouped ``qkvz`` and ``ba`` tensors (as produced by + :func:`split_gdn_linear_weights`) back into four flat tensors. + + Returns: + Tuple of (qkv, z, b, a) where each tensor has a flat per-component layout. + """ + hidden_size = config.hidden_size + qk_head_dim = config.linear_key_head_dim + v_head_dim = config.linear_value_head_dim + num_qk_heads = config.linear_num_key_heads + num_v_heads = config.linear_num_value_heads + v_per_group = num_v_heads // num_qk_heads + + expected_qkvz_dim0 = num_qk_heads * (qk_head_dim * 2 + v_per_group * v_head_dim * 2) + expected_ba_dim0 = num_qk_heads * v_per_group * 2 + if qkvz.ndim != 2 or qkvz.shape[0] != expected_qkvz_dim0 or qkvz.shape[1] != hidden_size: + raise ValueError( + f"qkvz shape mismatch: expected ({expected_qkvz_dim0}, {hidden_size}), got {tuple(qkvz.shape)}" + ) + if ba.ndim != 2 or ba.shape[0] != expected_ba_dim0 or ba.shape[1] != hidden_size: + raise ValueError(f"ba shape mismatch: expected ({expected_ba_dim0}, {hidden_size}), got {tuple(ba.shape)}") + + # --- Split grouped QKVZ --- + qkvz_g = qkvz.reshape(num_qk_heads, -1, hidden_size) + q_g, k_g, v_g, z_g = torch.split( + qkvz_g, + [qk_head_dim, qk_head_dim, v_per_group * v_head_dim, v_per_group * v_head_dim], + dim=1, + ) + q_flat = q_g.reshape(-1, hidden_size) + k_flat = k_g.reshape(-1, hidden_size) + v_flat = v_g.reshape(-1, hidden_size) + z_flat = z_g.reshape(-1, hidden_size) + qkv = torch.cat([q_flat, k_flat, v_flat], dim=0) + + # --- Split grouped BA --- + ba_g = ba.reshape(num_qk_heads, -1, hidden_size) + b_g, a_g = torch.split(ba_g, [v_per_group, v_per_group], dim=1) + b_flat = b_g.reshape(-1, hidden_size) + a_flat = a_g.reshape(-1, hidden_size) + + return qkv, z_flat, b_flat, a_flat + + def merge_kv_biases(config: TransformerConfig, k: torch.Tensor, v: 
torch.Tensor) -> torch.Tensor: """Merge separate K, V bias vectors into Megatron's interleaved KV format (1D).""" num_query_groups = config.num_query_groups diff --git a/src/megatron/bridge/models/qwen_vl/__init__.py b/src/megatron/bridge/models/qwen_vl/__init__.py index 8495eaef29..af9618d4ab 100644 --- a/src/megatron/bridge/models/qwen_vl/__init__.py +++ b/src/megatron/bridge/models/qwen_vl/__init__.py @@ -23,6 +23,8 @@ from megatron.bridge.models.qwen_vl.qwen25_vl_provider import ( Qwen25VLModelProvider, ) +from megatron.bridge.models.qwen_vl.qwen35_vl_bridge import Qwen35VLBridge, Qwen35VLMoEBridge +from megatron.bridge.models.qwen_vl.qwen35_vl_provider import Qwen35VLModelProvider, Qwen35VLMoEModelProvider __all__ = [ @@ -34,4 +36,8 @@ "Qwen3VLMoEBridge", "Qwen3VLModelProvider", "Qwen3VLMoEModelProvider", + "Qwen35VLBridge", + "Qwen35VLModelProvider", + "Qwen35VLMoEBridge", + "Qwen35VLMoEModelProvider", ] diff --git a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/attention.py b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/attention.py index 38573d3dbb..3e1e0406cd 100644 --- a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/attention.py +++ b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/attention.py @@ -102,7 +102,13 @@ def forward( # Get the query, key and value tensors based on the type of attention - # self or cross attn. 
nvtx_range_push(suffix="qkv") - query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states) + gate = None + if self.config.attention_output_gate: + query, key, value, gate = self.get_query_key_value_tensors( + hidden_states, key_value_states, output_gate=True + ) + else: + query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states) nvtx_range_pop(suffix="qkv") # =================================================== @@ -131,6 +137,8 @@ def forward( ) out = output.transpose(0, 1).contiguous() context_layer = out.view(out.size(0), out.size(1), -1) + if gate is not None: + context_layer = self._apply_output_gate(context_layer, gate) output, bias = self.linear_proj(context_layer) return output, bias @@ -260,6 +268,10 @@ def forward( core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1) nvtx_range_pop(suffix="core_attention") + # Output gate (for Gated Attention in hybrid architectures like Qwen3.5) + if gate is not None: + core_attn_out = self._apply_output_gate(core_attn_out, gate) + # ================= # Output. 
[sq, b, h] # ================= diff --git a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py index ba7f19425a..5d520e231e 100644 --- a/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py +++ b/src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py @@ -84,7 +84,8 @@ def __init__( ) -> None: super().__init__(config=language_transformer_config) - language_transformer_layer_spec.submodules.self_attention.module = Qwen3VLSelfAttention + if hasattr(language_transformer_layer_spec, "submodules"): + language_transformer_layer_spec.submodules.self_attention.module = Qwen3VLSelfAttention self.pre_process = pre_process self.post_process = post_process @@ -164,11 +165,10 @@ def __init__( pg_collection=pg_collection, ) if pre_process: - assert len(vision_transformer_config.deepstack_visual_indexes) <= len( - self.language_model.decoder.layers - ), ( + deepstack_indexes = getattr(vision_transformer_config, "deepstack_visual_indexes", []) + assert len(deepstack_indexes) <= len(self.language_model.decoder.layers), ( "the deepstack_visual_embeds should on the first pp-stage of language model", - f"got {len(vision_transformer_config.deepstack_visual_indexes)} deepstack_visual_indexes, " + f"got {len(deepstack_indexes)} deepstack_visual_indexes, " f" {len(self.language_model.decoder.layers)} language model layers", ) diff --git a/src/megatron/bridge/models/qwen_vl/qwen35_vl_bridge.py b/src/megatron/bridge/models/qwen_vl/qwen35_vl_bridge.py new file mode 100644 index 0000000000..fd26808bc5 --- /dev/null +++ b/src/megatron/bridge/models/qwen_vl/qwen35_vl_bridge.py @@ -0,0 +1,614 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Megatron Bridges for Qwen3.5 Vision-Language Models. + +Qwen3.5 is a family of multimodal models that combine: +- A hybrid Gated DeltaNet + Gated Attention language model (like Qwen3-Next) +- A vision encoder (similar to Qwen3-VL) +- Dense MLP or Mixture of Experts (MoE) with shared experts + +This module provides two bridges: + +- ``Qwen35VLBridge``: Dense variant (e.g., Qwen3.5-27B) + Reference: https://huggingface.co/Qwen/Qwen3.5-27B + +- ``Qwen35VLMoEBridge``: MoE variant (e.g., Qwen3.5-397B-A17B) + Reference: https://huggingface.co/Qwen/Qwen3.5-397B-A17B +""" + +import logging +import os + +import torch + +from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry +from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge +from megatron.bridge.models.conversion.param_mapping import ( + AutoMapping, + ConcatenatedQKVMapping, + GatedMLPMapping, + GDNConv1dMapping, + GDNLinearMappingSeparate, + QKVMapping, + ReplicatedMapping, + RMSNorm2ZeroCenteredRMSNormMapping, +) +from megatron.bridge.models.hf_pretrained.vlm import PreTrainedVLM +from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.model import Qwen3VLModel +from megatron.bridge.models.qwen_vl.qwen3_vl_bridge import ( + ExpertMLPDownProjMapping, + ExpertMLPGateUpProjMapping, + Qwen3VLMoEBridge, +) +from megatron.bridge.models.qwen_vl.qwen35_vl_provider import ( + Qwen35VLModelProvider, + Qwen35VLMoEModelProvider, +) + + +logger = logging.getLogger(__name__) + +_QWEN3_5_DENSE_HF_CLASS_NAME = "Qwen3_5ForConditionalGeneration" 
_QWEN3_5_MOE_HF_CLASS_NAME = "Qwen3_5MoeForConditionalGeneration"


@MegatronModelBridge.register_bridge(
    source=_QWEN3_5_MOE_HF_CLASS_NAME,
    target=Qwen3VLModel,
    provider=Qwen35VLMoEModelProvider,
    model_type="qwen3_5_moe",
)
class Qwen35VLMoEBridge(Qwen3VLMoEBridge):
    """
    Megatron Bridge for the Qwen3.5 MoE Vision-Language Model.

    Handles conversion between the HuggingFace Qwen3.5 VL model and the
    Megatron-Core Qwen3VLModel format: configuration translation plus
    weight mappings for the hybrid GDN+Attention VLM architecture.

    The weight mappings cover:
    - Language model hybrid layers (Gated DeltaNet + standard attention)
    - MoE layers with routed and shared experts
    - Vision model weights (same as Qwen3-VL: deepstack, merger, patch embed)
    - QK layernorm, zero-centered RMSNorm for the GDN output norm
    - mRoPE position embeddings

    Architecture: 15 × (3 × (GDN → MoE) + 1 × (Attention → MoE)) = 60 layers

    Example:
        >>> from megatron.bridge import AutoBridge
        >>> bridge = AutoBridge.from_hf_pretrained("Qwen/Qwen3.5-397B-A17B")
        >>> provider = bridge.to_megatron_provider()
    """

    def provider_bridge(self, hf_pretrained: PreTrainedVLM) -> Qwen35VLMoEModelProvider:
        """
        Create a Qwen35VLMoEModelProvider from a HuggingFace pretrained model.

        Extracts both language model and vision model configurations from the
        HuggingFace config and maps them to Megatron provider parameters.

        Args:
            hf_pretrained: HuggingFace pretrained VLM model

        Returns:
            Qwen35VLMoEModelProvider configured with the HF model's parameters
        """
        hf_config = hf_pretrained.config
        text_config = hf_config.text_config

        # Use base class utility to extract common config fields
        provider_kwargs = self.hf_config_to_provider_kwargs(text_config)

        vision_config = hf_config.vision_config
        vision_config.torch_dtype = provider_kwargs.get("params_dtype", torch.float32)

        provider = Qwen35VLMoEModelProvider(**provider_kwargs)

        # --- Common Qwen3 LLM settings ---
        provider.normalization = "RMSNorm"
        provider.gated_linear_unit = True
        provider.add_qkv_bias = getattr(text_config, "attention_bias", False)
        provider.add_bias_linear = False
        provider.qk_layernorm = True
        provider.hidden_dropout = 0.0

        # --- Qwen3-Next hybrid architecture settings ---
        provider.layernorm_zero_centered_gamma = True
        provider.attention_output_gate = True
        provider.experimental_attention_variant = "gated_delta_net"
        # full_attention_interval defines how often standard attention appears:
        # e.g., 4 means every 4th layer is standard attention (3 GDN + 1 Attn)
        provider.linear_attention_freq = getattr(text_config, "full_attention_interval", 4)
        # NOTE: HF configs may carry these attributes with an explicit value of
        # None; `getattr(..., {})` alone would then raise AttributeError on
        # `.get`, so normalize None to an empty dict before the lookup.
        rope_parameters = getattr(text_config, "rope_parameters", None) or {}
        provider.rotary_percent = rope_parameters.get("partial_rotary_factor", 0.25)

        # --- MoE specific parameters ---
        provider.moe_ffn_hidden_size = getattr(text_config, "moe_intermediate_size", 1024)
        provider.num_moe_experts = getattr(text_config, "num_experts", 512)
        provider.moe_router_topk = getattr(text_config, "num_experts_per_tok", 10)
        provider.moe_shared_expert_intermediate_size = getattr(text_config, "shared_expert_intermediate_size", None)
        provider.moe_shared_expert_gate = True
        provider.moe_grouped_gemm = True
        provider.moe_router_load_balancing_type = "global_aux_loss"
        provider.moe_router_pre_softmax = False
        provider.moe_token_dispatcher_type = "alltoall"
        provider.moe_permute_fusion = True

        # --- GDN (Gated DeltaNet) specific parameters ---
        provider.linear_conv_kernel_dim = getattr(text_config, "linear_conv_kernel_dim", 4)
        provider.linear_key_head_dim = getattr(text_config, "linear_key_head_dim", 128)
        provider.linear_value_head_dim = getattr(text_config, "linear_value_head_dim", 128)
        provider.linear_num_key_heads = getattr(text_config, "linear_num_key_heads", 16)
        provider.linear_num_value_heads = getattr(text_config, "linear_num_value_heads", 64)

        # --- VL-specific overrides ---
        provider.position_embedding_type = "mrope"
        provider.vision_config = vision_config
        provider.hf_text_config = text_config
        provider.head_dim = getattr(text_config, "head_dim", 256)
        provider.bos_token_id = getattr(text_config, "bos_token_id", 248045)
        provider.eos_token_id = getattr(text_config, "eos_token_id", 248046)
        provider.vision_start_token_id = getattr(hf_config, "vision_start_token_id", 248053)
        provider.vision_end_token_id = getattr(hf_config, "vision_end_token_id", 248054)
        provider.image_token_id = getattr(hf_config, "image_token_id", 248056)
        provider.video_token_id = getattr(hf_config, "video_token_id", 248057)
        provider.audio_token_id = getattr(hf_config, "audio_token_id", 248076)

        # Qwen3.5 uses mRoPE with [11, 11, 10] sections (different from Qwen3-VL's [24, 20, 20]).
        # The sections correspond to [temporal, height, width] dimensions.
        # With partial_rotary_factor=0.25 and head_dim=256, rotary_dim=64,
        # so each pair needs 32 dims total → sections [11, 11, 10].
        # Same None-guard as rope_parameters above: rope_scaling is often None.
        rope_scaling = getattr(text_config, "rope_scaling", None) or {}
        provider.mrope_section = rope_scaling.get("mrope_section", [11, 11, 10])

        # --- DEBUG: tiny model for quick testing ---
        # Set QWEN35_DEBUG=1 to shrink the text model to 4 layers
        # (one full GDN-GDN-GDN-Attn group). Expert count and router top-k
        # are left unchanged. Useful for conversion / forward-pass smoke tests.
        if os.environ.get("QWEN35_DEBUG", "0") == "1":
            logger.warning("QWEN35_DEBUG=1: overriding to tiny 4-layer model for debugging")
            provider.num_layers = 4  # 3 GDN + 1 Attn (one full group)

        return provider

    def mapping_registry(self) -> MegatronMappingRegistry:
        """
        Return MegatronMappingRegistry containing parameter mappings for Qwen3.5 VL.

        Combines:
        1. Language model mappings (Qwen3-Next hybrid architecture with VL prefixes):
           - Standard attention: QKV, output projection, QK layernorm
           - Linear attention (GDN): in_proj, out_proj, conv1d, A_log, dt_bias, out_norm
           - MoE: router, routed expert MLPs, shared expert MLPs, shared expert gate
           - Embeddings, output layer, final layernorm

        2. Vision model mappings (Qwen3-VL style):
           - Vision transformer blocks: attention, MLP, layer norms
           - Deepstack visual mergers
           - Patch embedding and position embedding
           - Final merger (patch_norm, linear_fc1, linear_fc2)

        Naming Convention:
        - Megatron language model params are prefixed with "language_model."
        - HF language model params are prefixed with "model.language_model."
        - Megatron vision model params are prefixed with "vision_model."
        - HF vision model params are prefixed with "model.visual."

        Returns:
            MegatronMappingRegistry with all parameter mappings
        """

        # =====================================================================
        # Simple 1:1 parameter mappings
        # =====================================================================
        param_mappings = {
            # =================================================================
            # Language Model: Embeddings and output
            # =================================================================
            "language_model.embedding.word_embeddings.weight": "model.language_model.embed_tokens.weight",
            "language_model.output_layer.weight": "lm_head.weight",
            "language_model.decoder.final_layernorm.weight": "model.language_model.norm.weight",
            # =================================================================
            # Language Model: MoE router
            # =================================================================
            "language_model.decoder.layers.*.mlp.router.weight": "model.language_model.layers.*.mlp.gate.weight",
            "language_model.decoder.layers.*.pre_mlp_layernorm.weight": "model.language_model.layers.*.post_attention_layernorm.weight",
            # =================================================================
            # Language Model: Standard attention layers (Gated Attention)
            # These mappings apply to layers where standard attention is used
            # (every 4th layer in the 15 × (3 GDN + 1 Attn) pattern)
            # =================================================================
            "language_model.decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "model.language_model.layers.*.input_layernorm.weight",
            "language_model.decoder.layers.*.self_attention.q_layernorm.weight": "model.language_model.layers.*.self_attn.q_norm.weight",
            "language_model.decoder.layers.*.self_attention.k_layernorm.weight": "model.language_model.layers.*.self_attn.k_norm.weight",
            "language_model.decoder.layers.*.self_attention.linear_proj.weight": "model.language_model.layers.*.self_attn.o_proj.weight",
            # =================================================================
            # Language Model: Linear attention (Gated DeltaNet) layers
            # These mappings apply to layers where GDN is used
            # (3 out of every 4 layers)
            # =================================================================
            "language_model.decoder.layers.*.self_attention.in_proj.layer_norm_weight": "model.language_model.layers.*.input_layernorm.weight",
            "language_model.decoder.layers.*.self_attention.out_proj.weight": "model.language_model.layers.*.linear_attn.out_proj.weight",
            "language_model.decoder.layers.*.self_attention.A_log": "model.language_model.layers.*.linear_attn.A_log",
            "language_model.decoder.layers.*.self_attention.dt_bias": "model.language_model.layers.*.linear_attn.dt_bias",
            # =================================================================
            # Vision Model: Attention
            # =================================================================
            "vision_model.decoder.layers.*.self_attention.linear_proj.weight": "model.visual.blocks.*.attn.proj.weight",
            "vision_model.decoder.layers.*.self_attention.linear_proj.bias": "model.visual.blocks.*.attn.proj.bias",
            # =================================================================
            # Vision Model: MLP
            # =================================================================
            "vision_model.decoder.layers.*.mlp.linear_fc1.weight": "model.visual.blocks.*.mlp.linear_fc1.weight",
            "vision_model.decoder.layers.*.mlp.linear_fc1.bias": "model.visual.blocks.*.mlp.linear_fc1.bias",
            "vision_model.decoder.layers.*.mlp.linear_fc2.weight": "model.visual.blocks.*.mlp.linear_fc2.weight",
            "vision_model.decoder.layers.*.mlp.linear_fc2.bias": "model.visual.blocks.*.mlp.linear_fc2.bias",
            # =================================================================
            # Vision Model: Layer Norms
            # =================================================================
            "vision_model.decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "model.visual.blocks.*.norm1.weight",
            "vision_model.decoder.layers.*.self_attention.linear_qkv.layer_norm_bias": "model.visual.blocks.*.norm1.bias",
            "vision_model.decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "model.visual.blocks.*.norm2.weight",
            "vision_model.decoder.layers.*.mlp.linear_fc1.layer_norm_bias": "model.visual.blocks.*.norm2.bias",
            # =================================================================
            # Vision Model: Final Merger
            # =================================================================
            "vision_model.merger.patch_norm.**": "model.visual.merger.norm.**",
            "vision_model.merger.linear_fc1.weight": "model.visual.merger.linear_fc1.weight",
            "vision_model.merger.linear_fc1.bias": "model.visual.merger.linear_fc1.bias",
            "vision_model.merger.linear_fc2.weight": "model.visual.merger.linear_fc2.weight",
            "vision_model.merger.linear_fc2.bias": "model.visual.merger.linear_fc2.bias",
        }

        mapping_list = []

        # Convert simple 1:1 mappings to AutoMapping objects
        for megatron_param, hf_param in param_mappings.items():
            mapping_list.append(AutoMapping(megatron_param=megatron_param, hf_param=hf_param))

        # Register module types for GDN and shared expert (needed for AutoMapping detection)
        AutoMapping.register_module_type("SharedExpertMLP", "column")
        AutoMapping.register_module_type("GatedDeltaNet", "column")

        # =====================================================================
        # Special mappings requiring parameter transformation
        # =====================================================================
        mapping_list.extend(
            [
                # =============================================================
                # Language Model: Standard Attention QKV
                # Combines separate Q, K, V matrices into single QKV matrix
                # =============================================================
                QKVMapping(
                    megatron_param="language_model.decoder.layers.*.self_attention.linear_qkv.weight",
                    q="model.language_model.layers.*.self_attn.q_proj.weight",
                    k="model.language_model.layers.*.self_attn.k_proj.weight",
                    v="model.language_model.layers.*.self_attn.v_proj.weight",
                ),
                # =============================================================
                # Language Model: GDN (Gated DeltaNet) specific mappings
                # =============================================================
                # GDN Conv1d: depthwise causal convolution
                GDNConv1dMapping(
                    megatron_param="language_model.decoder.layers.*.self_attention.conv1d.weight",
                    hf_param="model.language_model.layers.*.linear_attn.conv1d.weight",
                ),
                # GDN Input Projection: Qwen3.5 stores 4 separate weight tensors
                # (in_proj_qkv, in_proj_z, in_proj_b, in_proj_a) instead of the
                # 2 fused tensors (in_proj_qkvz, in_proj_ba) used by Qwen3-Next.
                GDNLinearMappingSeparate(
                    megatron_param="language_model.decoder.layers.*.self_attention.in_proj.weight",
                    qkv="model.language_model.layers.*.linear_attn.in_proj_qkv.weight",
                    z="model.language_model.layers.*.linear_attn.in_proj_z.weight",
                    b="model.language_model.layers.*.linear_attn.in_proj_b.weight",
                    a="model.language_model.layers.*.linear_attn.in_proj_a.weight",
                ),
                # GDN Output Norm: zero-centered RMSNorm conversion.
                # Qwen3-Next uses standard RMSNorm initialized to ones for output norm,
                # while Megatron uses zero-centered RMSNorm, so we subtract 1 during conversion.
                RMSNorm2ZeroCenteredRMSNormMapping(
                    "language_model.decoder.layers.*.self_attention.out_norm.weight",
                    "model.language_model.layers.*.linear_attn.norm.weight",
                ),
                # =============================================================
                # Language Model: MoE Expert MLPs (routed experts)
                # Uses GatedMLPMapping for gate+up projection fusion
                # =============================================================
                ExpertMLPGateUpProjMapping(
                    megatron_param="language_model.decoder.layers.*.mlp.experts.linear_fc1.weight*",
                    hf_param="model.language_model.layers.*.mlp.experts.gate_up_proj",
                ),
                ExpertMLPDownProjMapping(
                    megatron_param="language_model.decoder.layers.*.mlp.experts.linear_fc2.weight*",
                    hf_param="model.language_model.layers.*.mlp.experts.down_proj",
                ),
                # =============================================================
                # Language Model: Shared Expert MLPs
                # =============================================================
                GatedMLPMapping(
                    megatron_param="language_model.decoder.layers.*.mlp.shared_experts.linear_fc1.weight",
                    gate="model.language_model.layers.*.mlp.shared_expert.gate_proj.weight",
                    up="model.language_model.layers.*.mlp.shared_expert.up_proj.weight",
                ),
                AutoMapping(
                    megatron_param="language_model.decoder.layers.*.mlp.shared_experts.linear_fc2.weight",
                    hf_param="model.language_model.layers.*.mlp.shared_expert.down_proj.weight",
                ),
                # Shared expert gate weight (replicated across TP ranks)
                ReplicatedMapping(
                    megatron_param="language_model.decoder.layers.*.mlp.shared_experts.gate_weight",
                    hf_param="model.language_model.layers.*.mlp.shared_expert_gate.weight",
                ),
                # =============================================================
                # Vision Model: QKV (concatenated format)
                # =============================================================
                ConcatenatedQKVMapping(
                    megatron_param="vision_model.decoder.layers.*.self_attention.linear_qkv.weight",
                    hf_param="model.visual.blocks.*.attn.qkv.weight",
                ),
                ConcatenatedQKVMapping(
                    megatron_param="vision_model.decoder.layers.*.self_attention.linear_qkv.bias",
                    hf_param="model.visual.blocks.*.attn.qkv.bias",
                ),
                # =============================================================
                # Vision Model: Patch embedding (replicated across TP ranks)
                # These are conv layers that must be replicated
                # =============================================================
                ReplicatedMapping(
                    megatron_param="vision_model.patch_embed.proj.**",
                    hf_param="model.visual.patch_embed.proj.**",
                ),
                ReplicatedMapping(
                    megatron_param="vision_model.pos_embed.weight",
                    hf_param="model.visual.pos_embed.weight",
                ),
            ]
        )

        # TODO: MTP (Multi-Token Prediction) mappings for VL context, e.g.:
        #   "language_model.mtp.layers.0.eh_proj.weight": "mtp.fc.weight"
        #   "language_model.mtp.layers.0.enorm.weight": "mtp.pre_fc_norm_embedding.weight"
        #   "language_model.mtp.layers.0.hnorm.weight": "mtp.pre_fc_norm_hidden.weight"
        #   "language_model.mtp.layers.0.final_layernorm.weight": "mtp.norm.weight"
        # plus QKV, expert MLP, and shared expert mappings for MTP layers.
        # The exact prefix structure (language_model.mtp.* vs mtp.*) needs verification.

        return MegatronMappingRegistry(*mapping_list)
@MegatronModelBridge.register_bridge(
    source=_QWEN3_5_DENSE_HF_CLASS_NAME,
    target=Qwen3VLModel,
    provider=Qwen35VLModelProvider,
    model_type="qwen3_5",
)
class Qwen35VLBridge(MegatronModelBridge):
    """
    Megatron Bridge for Qwen3.5 Dense Vision-Language Model.

    This bridge handles the conversion between HuggingFace Qwen3.5 dense VL model
    and Megatron-Core Qwen3VLModel formats. Unlike the MoE variant, this model uses
    a standard dense MLP (gate_proj + up_proj → linear_fc1, down_proj → linear_fc2).

    The weight mappings handle:
    - Language model hybrid layers (GDN + standard attention)
    - Dense MLP with gated SiLU activation (fused pre-MLP layernorm)
    - Vision model weights (no deepstack mergers)
    - QK layernorm, zero-centered RMSNorm for GDN output norm
    - mRoPE position embeddings

    Architecture (27B): 16 × (3 × GDN + 1 × Attention) = 64 layers

    Example:
        >>> from megatron.bridge import AutoBridge
        >>> bridge = AutoBridge.from_hf_pretrained("Qwen/Qwen3.5-27B")
        >>> provider = bridge.to_megatron_provider()
    """

    def provider_bridge(self, hf_pretrained: PreTrainedVLM) -> Qwen35VLModelProvider:
        """
        Create a Qwen35VLModelProvider from a HuggingFace pretrained model.

        Args:
            hf_pretrained: HuggingFace pretrained VLM model

        Returns:
            Qwen35VLModelProvider configured with the HF model's parameters
        """
        hf_config = hf_pretrained.config
        text_config = hf_config.text_config

        provider_kwargs = self.hf_config_to_provider_kwargs(text_config)

        vision_config = hf_config.vision_config
        vision_config.torch_dtype = provider_kwargs.get("params_dtype", torch.float32)

        provider = Qwen35VLModelProvider(**provider_kwargs)

        # --- Common Qwen3 LLM settings ---
        provider.normalization = "RMSNorm"
        provider.gated_linear_unit = True
        provider.add_qkv_bias = getattr(text_config, "attention_bias", False)
        provider.add_bias_linear = False
        provider.qk_layernorm = True
        provider.hidden_dropout = 0.0

        # --- Qwen3-Next hybrid architecture settings ---
        provider.layernorm_zero_centered_gamma = True
        provider.attention_output_gate = True
        provider.experimental_attention_variant = "gated_delta_net"
        provider.linear_attention_freq = getattr(text_config, "full_attention_interval", 4)
        # NOTE: HF configs may carry rope_parameters / rope_scaling with an
        # explicit value of None; `getattr(..., {})` alone would then raise
        # AttributeError on `.get`, so normalize None to {} before the lookup.
        rope_parameters = getattr(text_config, "rope_parameters", None) or {}
        provider.rotary_percent = rope_parameters.get("partial_rotary_factor", 0.25)

        # --- GDN (Gated DeltaNet) specific parameters ---
        provider.linear_conv_kernel_dim = getattr(text_config, "linear_conv_kernel_dim", 4)
        provider.linear_key_head_dim = getattr(text_config, "linear_key_head_dim", 128)
        provider.linear_value_head_dim = getattr(text_config, "linear_value_head_dim", 128)
        provider.linear_num_key_heads = getattr(text_config, "linear_num_key_heads", 16)
        provider.linear_num_value_heads = getattr(text_config, "linear_num_value_heads", 48)

        # --- VL-specific overrides ---
        provider.position_embedding_type = "mrope"
        provider.vision_config = vision_config
        provider.hf_text_config = text_config
        provider.head_dim = getattr(text_config, "head_dim", 256)
        provider.bos_token_id = getattr(text_config, "bos_token_id", 248045)
        provider.eos_token_id = getattr(text_config, "eos_token_id", 248044)
        provider.vision_start_token_id = getattr(hf_config, "vision_start_token_id", 248053)
        provider.vision_end_token_id = getattr(hf_config, "vision_end_token_id", 248054)
        provider.image_token_id = getattr(hf_config, "image_token_id", 248056)
        provider.video_token_id = getattr(hf_config, "video_token_id", 248057)
        # Same None-guard as rope_parameters above.
        rope_scaling = getattr(text_config, "rope_scaling", None) or {}
        provider.mrope_section = rope_scaling.get("mrope_section", [11, 11, 10])

        return provider

    def mapping_registry(self) -> MegatronMappingRegistry:
        """
        Return MegatronMappingRegistry for Qwen3.5 dense VL model.

        Key differences from the MoE variant:
        - Dense MLP: gate_proj + up_proj fused into linear_fc1, down_proj as linear_fc2
        - Pre-MLP layernorm fused into mlp.linear_fc1 (not a separate pre_mlp_layernorm)
        - No MoE router, routed expert MLPs, or shared expert mappings
        - No deepstack visual mergers (deepstack_visual_indexes is empty)
        """

        param_mappings = {
            # =================================================================
            # Language Model: Embeddings and output
            # =================================================================
            "language_model.embedding.word_embeddings.weight": "model.language_model.embed_tokens.weight",
            "language_model.output_layer.weight": "lm_head.weight",
            "language_model.decoder.final_layernorm.weight": "model.language_model.norm.weight",
            # =================================================================
            # Language Model: Dense MLP (pre-MLP layernorm fused into linear_fc1)
            # =================================================================
            "language_model.decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "model.language_model.layers.*.post_attention_layernorm.weight",
            "language_model.decoder.layers.*.mlp.linear_fc2.weight": "model.language_model.layers.*.mlp.down_proj.weight",
            # =================================================================
            # Language Model: Standard attention layers (Gated Attention)
            # =================================================================
            "language_model.decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "model.language_model.layers.*.input_layernorm.weight",
            "language_model.decoder.layers.*.self_attention.q_layernorm.weight": "model.language_model.layers.*.self_attn.q_norm.weight",
            "language_model.decoder.layers.*.self_attention.k_layernorm.weight": "model.language_model.layers.*.self_attn.k_norm.weight",
            "language_model.decoder.layers.*.self_attention.linear_proj.weight": "model.language_model.layers.*.self_attn.o_proj.weight",
            # =================================================================
            # Language Model: Linear attention (Gated DeltaNet) layers
            # =================================================================
            "language_model.decoder.layers.*.self_attention.in_proj.layer_norm_weight": "model.language_model.layers.*.input_layernorm.weight",
            "language_model.decoder.layers.*.self_attention.out_proj.weight": "model.language_model.layers.*.linear_attn.out_proj.weight",
            "language_model.decoder.layers.*.self_attention.A_log": "model.language_model.layers.*.linear_attn.A_log",
            "language_model.decoder.layers.*.self_attention.dt_bias": "model.language_model.layers.*.linear_attn.dt_bias",
            # =================================================================
            # Vision Model: Attention
            # =================================================================
            "vision_model.decoder.layers.*.self_attention.linear_proj.weight": "model.visual.blocks.*.attn.proj.weight",
            "vision_model.decoder.layers.*.self_attention.linear_proj.bias": "model.visual.blocks.*.attn.proj.bias",
            # =================================================================
            # Vision Model: MLP
            # =================================================================
            "vision_model.decoder.layers.*.mlp.linear_fc1.weight": "model.visual.blocks.*.mlp.linear_fc1.weight",
            "vision_model.decoder.layers.*.mlp.linear_fc1.bias": "model.visual.blocks.*.mlp.linear_fc1.bias",
            "vision_model.decoder.layers.*.mlp.linear_fc2.weight": "model.visual.blocks.*.mlp.linear_fc2.weight",
            "vision_model.decoder.layers.*.mlp.linear_fc2.bias": "model.visual.blocks.*.mlp.linear_fc2.bias",
            # =================================================================
            # Vision Model: Layer Norms
            # =================================================================
            "vision_model.decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "model.visual.blocks.*.norm1.weight",
            "vision_model.decoder.layers.*.self_attention.linear_qkv.layer_norm_bias": "model.visual.blocks.*.norm1.bias",
            "vision_model.decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "model.visual.blocks.*.norm2.weight",
            "vision_model.decoder.layers.*.mlp.linear_fc1.layer_norm_bias": "model.visual.blocks.*.norm2.bias",
            # =================================================================
            # Vision Model: Final Merger (no deepstack in dense variant)
            # =================================================================
            "vision_model.merger.patch_norm.**": "model.visual.merger.norm.**",
            "vision_model.merger.linear_fc1.weight": "model.visual.merger.linear_fc1.weight",
            "vision_model.merger.linear_fc1.bias": "model.visual.merger.linear_fc1.bias",
            "vision_model.merger.linear_fc2.weight": "model.visual.merger.linear_fc2.weight",
            "vision_model.merger.linear_fc2.bias": "model.visual.merger.linear_fc2.bias",
        }

        mapping_list = []
        for megatron_param, hf_param in param_mappings.items():
            mapping_list.append(AutoMapping(megatron_param=megatron_param, hf_param=hf_param))

        AutoMapping.register_module_type("GatedDeltaNet", "column")

        mapping_list.extend(
            [
                # =============================================================
                # Language Model: Standard Attention QKV
                # =============================================================
                QKVMapping(
                    megatron_param="language_model.decoder.layers.*.self_attention.linear_qkv.weight",
                    q="model.language_model.layers.*.self_attn.q_proj.weight",
                    k="model.language_model.layers.*.self_attn.k_proj.weight",
                    v="model.language_model.layers.*.self_attn.v_proj.weight",
                ),
                # =============================================================
                # Language Model: Dense MLP (gated: gate_proj + up_proj → linear_fc1)
                # =============================================================
                GatedMLPMapping(
                    megatron_param="language_model.decoder.layers.*.mlp.linear_fc1.weight",
                    gate="model.language_model.layers.*.mlp.gate_proj.weight",
                    up="model.language_model.layers.*.mlp.up_proj.weight",
                ),
                # =============================================================
                # Language Model: GDN (Gated DeltaNet) specific mappings
                # =============================================================
                GDNConv1dMapping(
                    megatron_param="language_model.decoder.layers.*.self_attention.conv1d.weight",
                    hf_param="model.language_model.layers.*.linear_attn.conv1d.weight",
                ),
                GDNLinearMappingSeparate(
                    megatron_param="language_model.decoder.layers.*.self_attention.in_proj.weight",
                    qkv="model.language_model.layers.*.linear_attn.in_proj_qkv.weight",
                    z="model.language_model.layers.*.linear_attn.in_proj_z.weight",
                    b="model.language_model.layers.*.linear_attn.in_proj_b.weight",
                    a="model.language_model.layers.*.linear_attn.in_proj_a.weight",
                ),
                RMSNorm2ZeroCenteredRMSNormMapping(
                    "language_model.decoder.layers.*.self_attention.out_norm.weight",
                    "model.language_model.layers.*.linear_attn.norm.weight",
                ),
                # =============================================================
                # Vision Model: QKV (concatenated format)
                # =============================================================
                ConcatenatedQKVMapping(
                    megatron_param="vision_model.decoder.layers.*.self_attention.linear_qkv.weight",
                    hf_param="model.visual.blocks.*.attn.qkv.weight",
                ),
                ConcatenatedQKVMapping(
                    megatron_param="vision_model.decoder.layers.*.self_attention.linear_qkv.bias",
                    hf_param="model.visual.blocks.*.attn.qkv.bias",
                ),
                # =============================================================
                # Vision Model: Patch embedding (replicated across TP ranks)
                # =============================================================
                ReplicatedMapping(
                    megatron_param="vision_model.patch_embed.proj.**",
                    hf_param="model.visual.patch_embed.proj.**",
                ),
                ReplicatedMapping(
                    megatron_param="vision_model.pos_embed.weight",
                    hf_param="model.visual.pos_embed.weight",
                ),
            ]
        )

        # TODO: MTP mappings for dense Qwen3.5 VL (mtp_num_hidden_layers=1 in config)

        return MegatronMappingRegistry(*mapping_list)
"""
Qwen3.5 VL Model Provider configurations for Megatron-Core.

Qwen3.5 is a family of vision-language models that combine:
- A hybrid Gated DeltaNet (GDN) + Gated Attention language model (like Qwen3-Next)
- A vision encoder (similar to Qwen3-VL)
- Dense MLP or Mixture of Experts (MoE) with shared experts

This module provides two model providers:

- ``Qwen35VLModelProvider``: Dense variant (e.g., Qwen3.5-27B)
  Reference: https://huggingface.co/Qwen/Qwen3.5-27B

- ``Qwen35VLMoEModelProvider``: MoE variant (e.g., Qwen3.5-397B-A17B)
  Reference: https://huggingface.co/Qwen/Qwen3.5-397B-A17B
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional

import transformers
from megatron.core.models.gpt import GPTModel as MCoreGPTModel
from megatron.core.models.gpt.experimental_attention_variant_module_specs import (
    get_transformer_block_with_experimental_attention_variant_spec,
)
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
from packaging.version import Version as PkgVersion


# MoE configs shipped with transformers 5.2.0; detect support by version
# rather than try/except so the flag is usable before any import is attempted.
_TRANSFORMERS_HAS_QWEN3_5_MOE = PkgVersion(transformers.__version__) >= PkgVersion("5.2.0")

if _TRANSFORMERS_HAS_QWEN3_5_MOE:
    from transformers.models.qwen3_5_moe.configuration_qwen3_5_moe import Qwen3_5MoeVisionConfig
else:
    Qwen3_5MoeVisionConfig = None  # type: ignore[assignment,misc]

# The dense variant is probed by import rather than version.
try:
    from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5VisionConfig

    _TRANSFORMERS_HAS_QWEN3_5 = True
except ImportError:
    _TRANSFORMERS_HAS_QWEN3_5 = False
    Qwen3_5VisionConfig = None  # type: ignore[assignment,misc]

from megatron.bridge.models.gpt_provider import GPTModelProvider
from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.attention import Qwen3VLSelfAttention
from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.model import Qwen3VLModel


def _check_qwen3_5_available() -> None:
    """Raise a clear error if transformers doesn't have qwen3_5 (dense) support."""
    if not _TRANSFORMERS_HAS_QWEN3_5:
        raise ImportError(
            f"Qwen3.5 VL (dense) requires transformers with qwen3_5 model support, "
            f"but found {transformers.__version__}. "
            "Please upgrade: pip install --upgrade transformers"
        )


def _check_qwen3_5_moe_available() -> None:
    """Raise a clear error if transformers doesn't have qwen3_5_moe support."""
    if not _TRANSFORMERS_HAS_QWEN3_5_MOE:
        raise ImportError(
            f"Qwen3.5 VL (MoE) requires transformers >= 5.2.0, but found {transformers.__version__}. "
            "Please upgrade: pip install --upgrade transformers"
        )


@dataclass
class Qwen35VLModelProvider(GPTModelProvider):
    """
    Model provider for Qwen3.5 VL Dense (Vision-Language) Models.

    Qwen3.5 dense combines a hybrid GDN (Gated DeltaNet) + Gated Attention language
    model architecture with a vision encoder (similar to Qwen3-VL) and a standard
    dense MLP (no Mixture of Experts).

    Key Architecture Details (27B):
    - 64 layers: 16 groups x (3 GDN + 1 Attention)
    - Hidden dim: 5120, Intermediate dim: 17408
    - GDN: 16 QK heads, 48 V heads, head_dim=128
    - Gated Attention: 24 Q heads, 4 KV heads, head_dim=256
    - Vision: depth=27, hidden=1152, no deepstack
    - mRoPE with sections [11, 11, 10], rope_theta=10,000,000
    - partial_rotary_factor=0.25
    """

    # =========================================================================
    # Hybrid Architecture (Qwen3-Next style)
    # =========================================================================
    # Default spec builder is the experimental-attention-variant block; a
    # callable default is resolved by the provider machinery (see `provide`).
    transformer_layer_spec: ModuleSpec | Callable[["GPTModelProvider"], ModuleSpec] = (
        get_transformer_block_with_experimental_attention_variant_spec
    )
    layernorm_zero_centered_gamma: bool = True
    attention_output_gate: bool = True
    experimental_attention_variant: str = "gated_delta_net"
    # Every Nth layer is standard attention; the rest are GDN (4 → 3 GDN + 1 Attn).
    linear_attention_freq: int | list[int] = 4

    # --- Gated DeltaNet (GDN) parameters ---
    linear_conv_kernel_dim: int = 4
    linear_key_head_dim: int = 128
    linear_value_head_dim: int = 128
    linear_num_key_heads: int = 16
    linear_num_value_heads: int = 48

    # =========================================================================
    # Common LLM parameters
    # =========================================================================
    normalization: str = "RMSNorm"
    gated_linear_unit: bool = True
    add_bias_linear: bool = False
    add_qkv_bias: bool = False
    qk_layernorm: bool = True
    kv_channels: int | None = 256
    num_query_groups: int = 4
    hidden_dropout: float = 0.0
    attention_dropout: float = 0.0
    attention_softmax_in_fp32: bool = True
    rotary_base: float = 10000000.0
    rotary_percent: float = 0.25
    seq_length: int = 262144

    # =========================================================================
    # VL-specific parameters
    # =========================================================================
    # HF vision config object; defaulted to Qwen3_5VisionConfig() in __post_init__.
    vision_config: Any = field(default=None)
    position_embedding_type: str = "mrope"
    # mRoPE [temporal, height, width] rotary sections.
    mrope_section: List[int] = field(default_factory=lambda: [11, 11, 10])
    apply_rotary_pos_emb_in_fp32: bool = False

    # Special-token ids matching the Qwen3.5 tokenizer defaults.
    image_token_id: int = 248056
    video_token_id: int = 248057
    vision_start_token_id: int = 248053
    vision_end_token_id: int = 248054
    bos_token_id: int = 248045
    eos_token_id: int = 248044

    spatial_merge_size: int = 2
    temporal_patch_size: int = 2
    patch_size: int = 16
    language_max_sequence_length: int = 2048
    scatter_embedding_sequence_parallel: bool = False

    # =========================================================================
    # Freeze options for fine-tuning
    # =========================================================================
    freeze_language_model: bool = False
    freeze_vision_model: bool = False
    freeze_vision_projection: bool = False

    # =========================================================================
    # Performance
    # =========================================================================
    bias_activation_fusion: bool = True
    use_hf_vision_model: bool = False
    vision_dp_when_cp: bool = False
    # NOTE(review): field name carries a typo ("hetereogenous") — it is part of
    # the public interface, so it cannot be renamed without a migration.
    hetereogenous_dist_checkpoint: bool = True

    mtp_num_layers: Optional[int] = None

    def __post_init__(self):
        """Validate transformers support and TP-vs-GQA compatibility, then defer to the base class."""
        _check_qwen3_5_available()
        if self.vision_config is None:
            self.vision_config = Qwen3_5VisionConfig()
        # Each TP rank needs at least one KV group; reject over-sharded configs early.
        if self.num_query_groups < self.tensor_model_parallel_size:
            raise ValueError(
                f"TP size {self.tensor_model_parallel_size} should be less than or equal to num_query_groups {self.num_query_groups}. Please use a smaller TP size."
            )
        super().__post_init__()

    def provide(self, pre_process=None, post_process=None, vp_stage=None) -> Qwen3VLModel:
        """Provide a Qwen3.5 VL dense model instance with vision and language components."""
        language_transformer_config = self
        hf_vision_config = self.vision_config

        block_spec = get_transformer_block_with_experimental_attention_variant_spec(
            language_transformer_config,
            vp_stage=vp_stage,
        )
        # Swap the standard-attention spec for the VL-aware self-attention.
        # NOTE(review): _patch_standard_attention_specs is not among the visible
        # imports — presumably defined elsewhere in this module; verify.
        _patch_standard_attention_specs(block_spec, Qwen3VLSelfAttention)

        model = Qwen3VLModel(
            language_transformer_config=language_transformer_config,
            language_transformer_layer_spec=block_spec,
            vision_transformer_config=hf_vision_config,
            pre_process=pre_process,
            post_process=post_process,
            pg_collection=self._pg_collection,
        )

        # Optionally freeze submodules for fine-tuning runs.
        if self.freeze_language_model or self.freeze_vision_model or self.freeze_vision_projection:
            model.freeze(
                freeze_language_model=self.freeze_language_model,
                freeze_vision_model=self.freeze_vision_model,
                freeze_vision_projection=self.freeze_vision_projection,
            )

        return model

    def provide_language_model(self, pre_process=None, post_process=None, vp_stage=None) -> MCoreGPTModel:
        """Provide just the language model component without vision."""
        return GPTModelProvider.provide(self, pre_process=pre_process, post_process=post_process, vp_stage=vp_stage)
+ + Qwen 3.5 combines a hybrid GDN (Gated DeltaNet) + Gated Attention language model + architecture (like Qwen3-Next) with a vision encoder (similar to Qwen3-VL) and + Mixture of Experts (MoE) with shared experts. + + Key Architecture Details (397B-A17B): + - 60 layers: 15 groups × (3 GDN-MoE + 1 Attention-MoE) + - Hidden dim: 4096, Token Embedding: 248320 + - GDN: 16 QK heads, 64 V heads, head_dim=128 + - Gated Attention: 32 Q heads, 2 KV heads, head_dim=256 + - MoE: 512 experts, 10 routed + 1 shared, expert dim=1024 + - mRoPE with sections [11, 11, 10], rope_theta=10,000,000 + - partial_rotary_factor=0.25 + + Note: num_query_groups corresponds to num_key_value_heads in HF config (for + standard Gated Attention layers). GDN layers have separate head counts. + """ + + # ========================================================================= + # Hybrid Architecture (Qwen3-Next style) + # ========================================================================= + transformer_layer_spec: ModuleSpec | Callable[["GPTModelProvider"], ModuleSpec] = ( + get_transformer_block_with_experimental_attention_variant_spec + ) + layernorm_zero_centered_gamma: bool = True + attention_output_gate: bool = True + experimental_attention_variant: str = "gated_delta_net" + linear_attention_freq: int | list[int] = 4 # 1 standard attention per 4 layers + + # --- Gated DeltaNet (GDN) parameters --- + linear_conv_kernel_dim: int = 4 + linear_key_head_dim: int = 128 + linear_value_head_dim: int = 128 + linear_num_key_heads: int = 16 + linear_num_value_heads: int = 64 # 64 V heads for GDN in 397B model + + # ========================================================================= + # MoE parameters + # ========================================================================= + num_moe_experts: int = 512 + moe_router_topk: int = 10 # 10 routed experts per token + moe_shared_expert_gate: bool = True + moe_router_dtype: str = "fp32" + moe_router_load_balancing_type: str = "global_aux_loss" + 
moe_router_pre_softmax: bool = False + moe_grouped_gemm: bool = True + moe_token_dispatcher_type: str = "alltoall" + moe_permute_fusion: bool = True + moe_aux_loss_coeff: float = 1e-3 + + # ========================================================================= + # Common LLM parameters + # ========================================================================= + normalization: str = "RMSNorm" + gated_linear_unit: bool = True + add_bias_linear: bool = False + add_qkv_bias: bool = False + qk_layernorm: bool = True + kv_channels: int | None = 256 # head_dim for standard Gated Attention + num_query_groups: int = 2 # KV heads for standard Gated Attention + hidden_dropout: float = 0.0 + attention_dropout: float = 0.0 + attention_softmax_in_fp32: bool = True + rotary_base: float = 10000000.0 # rope_theta from HF config + rotary_percent: float = 0.25 # partial_rotary_factor from HF config + seq_length: int = 262144 # 262K native context length + + # ========================================================================= + # VL-specific parameters + # ========================================================================= + + vision_config: Any = field(default=None) + + # Position embedding: Qwen3.5 uses multimodal rope (mRoPE) + position_embedding_type: str = "mrope" + # Qwen3.5 mRoPE section is [11, 11, 10] (different from Qwen3-VL's [24, 20, 20]) + # because partial_rotary_factor=0.25, so RoPE dim = 256*0.25 = 64, with sections [11,11,10] + # for [temporal, height, width] summing to 32 (half of 64 rotary dim). 
+ mrope_section: List[int] = field(default_factory=lambda: [11, 11, 10]) + apply_rotary_pos_emb_in_fp32: bool = False + + # Vision-Language token IDs + image_token_id: int = 248056 + video_token_id: int = 248057 + vision_start_token_id: int = 248053 + vision_end_token_id: int = 248054 + bos_token_id: int = 248045 + eos_token_id: int = 248046 + + # Vision model settings + spatial_merge_size: int = 2 + temporal_patch_size: int = 2 + patch_size: int = 16 + language_max_sequence_length: int = 2048 + + scatter_embedding_sequence_parallel: bool = False + + # ========================================================================= + # Freeze options for fine-tuning + # ========================================================================= + freeze_language_model: bool = False + freeze_vision_model: bool = False + freeze_vision_projection: bool = False + + # ========================================================================= + # Performance + # ========================================================================= + bias_activation_fusion: bool = True + use_hf_vision_model: bool = False + vision_dp_when_cp: bool = False + + # Heterogeneous dist checkpoint (needed for hybrid architecture) + hetereogenous_dist_checkpoint: bool = True + + # TODO: MTP (Multi-Token Prediction) support for VL context. + # Qwen3.5 model card states "MTP: trained with multi-steps" but it's unclear + # how MTP interacts with the vision encoder in VL mode. + mtp_num_layers: Optional[int] = None + + def __post_init__(self): + _check_qwen3_5_moe_available() + if self.vision_config is None: + self.vision_config = Qwen3_5MoeVisionConfig() + if self.num_query_groups < self.tensor_model_parallel_size: + raise ValueError( + f"TP size {self.tensor_model_parallel_size} should be less than or equal to num_query_groups {self.num_query_groups}. Please use a smaller TP size." 
+ ) + super().__post_init__() + + def provide(self, pre_process=None, post_process=None, vp_stage=None) -> Qwen3VLModel: + """Provide a Qwen3.5 VL model instance with vision and language components. + + Qwen3.5 uses a hybrid architecture (GDN + standard attention). The key + challenge is that Qwen3VLModel.__init__ does:: + + language_transformer_layer_spec.submodules.self_attention.module = Qwen3VLSelfAttention + + which assumes a single ModuleSpec and patches ALL layers uniformly. + For Qwen3.5, only the standard attention layers (every 4th layer) should + get the Qwen3VLSelfAttention override; GDN layers must be left alone. + + Solution: build the hybrid TransformerBlockSubmodules spec, selectively + patch only the standard attention layer specs, then pass it to + Qwen3VLModel. Because GPTModel → TransformerBlock already accepts + TransformerBlockSubmodules, we just need to bypass the uniform patch + in Qwen3VLModel.__init__ by calling MegatronModule.__init__ directly + and constructing the internals ourselves. + """ + language_transformer_config = self + hf_vision_config = self.vision_config + + # Build hybrid block spec: produces TransformerBlockSubmodules with + # per-layer specs (GDN layers get GatedDeltaNet, attention layers get + # standard SelfAttention + MoE). + block_spec = get_transformer_block_with_experimental_attention_variant_spec( + language_transformer_config, + vp_stage=vp_stage, + ) + + # Selectively patch only the standard (full) attention layer specs + # with Qwen3VLSelfAttention for mRoPE support. GDN layers are left as-is. + _patch_standard_attention_specs(block_spec, Qwen3VLSelfAttention) + + # Qwen3VLModel expects a single ModuleSpec and does a uniform patch on + # line 87. We pass the block spec instead – GPTModel/TransformerBlock + # already handles TransformerBlockSubmodules natively. 
The uniform patch + # line will fail on a TransformerBlockSubmodules, so we bypass it by + # passing the spec through and letting the model ignore the single-spec + # assumption. We pass the block_spec as the layer spec; Qwen3VLGPTModel + # forwards it to TransformerBlock which handles both types. + model = Qwen3VLModel( + language_transformer_config=language_transformer_config, + language_transformer_layer_spec=block_spec, + vision_transformer_config=hf_vision_config, + pre_process=pre_process, + post_process=post_process, + pg_collection=self._pg_collection, + ) + + # Apply freeze options if any are enabled for fine-tuning + if self.freeze_language_model or self.freeze_vision_model or self.freeze_vision_projection: + model.freeze( + freeze_language_model=self.freeze_language_model, + freeze_vision_model=self.freeze_vision_model, + freeze_vision_projection=self.freeze_vision_projection, + ) + + return model + + def provide_language_model(self, pre_process=None, post_process=None, vp_stage=None) -> MCoreGPTModel: + """Provide just the language model component without vision.""" + return GPTModelProvider.provide(self, pre_process=pre_process, post_process=post_process, vp_stage=vp_stage) + + +def _patch_standard_attention_specs( + block_spec: TransformerBlockSubmodules, + attention_cls, +) -> None: + """Selectively replace the self_attention module on standard attention layer specs. + + In a hybrid block spec, each layer spec has a different self_attention submodule: + - Standard attention layers have a ``SelfAttention``-like module. + - GDN layers have a ``GatedDeltaNet``-like module. + + This function patches only the standard attention layers with *attention_cls* + (e.g. ``Qwen3VLSelfAttention`` for mRoPE support), leaving GDN layers unchanged. + + Detection heuristic: GDN layer specs have ``GatedDeltaNet`` (or similar) as the + self_attention module, which does NOT have a ``linear_qkv`` submodule. Standard + attention specs DO have ``linear_qkv``. 
We use this to distinguish them. + """ + from megatron.core.transformer.attention import SelfAttention + + for layer_spec in block_spec.layer_specs: + attn_spec = layer_spec.submodules.self_attention + # Standard attention specs use SelfAttention (or a subclass) as the module + # and have linear_qkv in their submodules. GDN specs use GatedDeltaNet. + if attn_spec.module is SelfAttention or ( + isinstance(attn_spec.module, type) and issubclass(attn_spec.module, SelfAttention) + ): + attn_spec.module = attention_cls diff --git a/src/megatron/bridge/models/qwen_vl/qwen3_vl_bridge.py b/src/megatron/bridge/models/qwen_vl/qwen3_vl_bridge.py index 20d1bc371d..895bfed520 100644 --- a/src/megatron/bridge/models/qwen_vl/qwen3_vl_bridge.py +++ b/src/megatron/bridge/models/qwen_vl/qwen3_vl_bridge.py @@ -474,6 +474,9 @@ class ExpertMLPDownProjMapping(AutoMapping): (standard) HF expert weight layouts are handled transparently. """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + def hf_to_megatron(self, hf_weights: torch.Tensor, megatron_module: nn.Module) -> torch.Tensor: global_expert_number = extract_expert_number_from_param(self.megatron_param) expert_weight = hf_weights[global_expert_number] if hf_weights.ndim >= 3 else hf_weights diff --git a/tests/functional_tests/L2_Launch_models_qwen35_vl.sh b/tests/functional_tests/L2_Launch_models_qwen35_vl.sh new file mode 100755 index 0000000000..b88a382e74 --- /dev/null +++ b/tests/functional_tests/L2_Launch_models_qwen35_vl.sh @@ -0,0 +1,23 @@ +#!/bin/bash +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -xeuo pipefail + +export CUDA_VISIBLE_DEVICES="0,1" + +uv run coverage run --data-file=/opt/Megatron-Bridge/.coverage --source=/opt/Megatron-Bridge/ --parallel-mode -m pytest \ + -o log_cli=true -o log_cli_level=INFO -v -s -x -m "not pleasefixme" --tb=short -rA \ + tests/functional_tests/models/qwen_vl/test_qwen35_vl_conversion.py +coverage combine -q diff --git a/tests/functional_tests/models/qwen_vl/test_qwen35_vl_conversion.py b/tests/functional_tests/models/qwen_vl/test_qwen35_vl_conversion.py new file mode 100644 index 0000000000..df3d233568 --- /dev/null +++ b/tests/functional_tests/models/qwen_vl/test_qwen35_vl_conversion.py @@ -0,0 +1,462 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Functional tests for Qwen3.5 VL HF ↔ Megatron roundtrip conversion. + +Qwen3.5 uses a hybrid Gated DeltaNet (GDN) + Gated Attention architecture. +The full_attention_interval=4 means every 4th layer is standard attention, +so num_hidden_layers must be a multiple of 4. 
+ +Run dense test: + uv run python -m torch.distributed.run --nproc_per_node=2 -m pytest \ + tests/functional_tests/models/qwen_vl/test_qwen35_vl_conversion.py::TestQwen35VLConversion -v -s + +Run MoE test: + uv run python -m torch.distributed.run --nproc_per_node=2 -m pytest \ + tests/functional_tests/models/qwen_vl/test_qwen35_vl_conversion.py::TestQwen35VLMoEConversion -v -s +""" + +import json +import re +import subprocess +from pathlib import Path + +import pytest +import torch + + +try: + from transformers import Qwen3_5ForConditionalGeneration + from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5Config + + _HAS_QWEN3_5 = True +except ImportError: + _HAS_QWEN3_5 = False + +try: + from transformers import Qwen3_5MoeForConditionalGeneration + from transformers.models.qwen3_5_moe.configuration_qwen3_5_moe import Qwen3_5MoeConfig + + _HAS_QWEN3_5_MOE = True +except ImportError: + _HAS_QWEN3_5_MOE = False + + +# --------------------------------------------------------------------------- +# Tiny dense config (Qwen3.5 dense style, ~small param count for fast tests) +# Mirrors the structure of the real Qwen3.5-27B config: +# https://huggingface.co/Qwen/Qwen3.5-27B/blob/main/config.json +# num_hidden_layers must be a multiple of full_attention_interval (4) +# --------------------------------------------------------------------------- +HF_QWEN35_VL_TOY_MODEL_CONFIG = { + "architectures": ["Qwen3_5ForConditionalGeneration"], + "image_token_id": 248056, + "model_type": "qwen3_5", + "torch_dtype": "bfloat16", + "text_config": { + "attention_bias": False, + "attention_dropout": 0.0, + "eos_token_id": 248044, + "full_attention_interval": 4, + "head_dim": 64, + "hidden_act": "silu", + "hidden_size": 256, + "initializer_range": 0.02, + "intermediate_size": 512, + "linear_conv_kernel_dim": 4, + "linear_key_head_dim": 32, + "linear_num_key_heads": 4, + "linear_num_value_heads": 4, + "linear_value_head_dim": 32, + "max_position_embeddings": 32768, + 
"model_type": "qwen3_5_text", + "num_attention_heads": 4, + "num_hidden_layers": 4, + "num_key_value_heads": 2, + "rms_norm_eps": 1e-06, + "torch_dtype": "bfloat16", + "use_cache": True, + "vocab_size": 2048, + "rope_parameters": { + "rope_type": "default", + "partial_rotary_factor": 0.25, + "rope_theta": 10000000.0, + "mrope_section": [8, 8, 8], + }, + }, + "tie_word_embeddings": False, + "video_token_id": 248057, + "vision_config": { + "deepstack_visual_indexes": [], + "depth": 1, + "hidden_act": "gelu_pytorch_tanh", + "hidden_size": 256, + "in_channels": 3, + "intermediate_size": 512, + "num_heads": 4, + "num_position_embeddings": 2304, + "out_hidden_size": 256, + "patch_size": 14, + "spatial_merge_size": 2, + "temporal_patch_size": 2, + }, + "vision_end_token_id": 248054, + "vision_start_token_id": 248053, +} + + +@pytest.mark.skipif(not _HAS_QWEN3_5, reason="transformers does not have Qwen3.5 (dense) support") +class TestQwen35VLConversion: + """Test Qwen3.5 VL dense model conversion from HuggingFace to Megatron.""" + + @pytest.fixture(scope="class") + def qwen35_vl_toy_model_path(self, tmp_path_factory): + """Create and save a dense Qwen3.5 VL toy model.""" + temp_dir = tmp_path_factory.mktemp("qwen35_vl_toy_model") + model_dir = temp_dir / "qwen35_vl_toy" + + config = Qwen3_5Config(**HF_QWEN35_VL_TOY_MODEL_CONFIG) + config.torch_dtype = torch.bfloat16 + + model = Qwen3_5ForConditionalGeneration(config) + model = model.to(dtype=torch.bfloat16) + + try: + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3.5-27B") + tokenizer.save_pretrained(model_dir) + except Exception: + tokenizer_config = { + "tokenizer_class": "Qwen2Tokenizer", + "vocab_size": 248320, + } + model_dir.mkdir(parents=True, exist_ok=True) + with open(model_dir / "tokenizer_config.json", "w") as f: + json.dump(tokenizer_config, f, indent=2) + + model.save_pretrained(model_dir, safe_serialization=True) + + return str(model_dir) + + def 
test_toy_model_creation(self, qwen35_vl_toy_model_path): + """Verify the toy model was created correctly.""" + model_path = Path(qwen35_vl_toy_model_path) + assert model_path.exists() + + config_file = model_path / "config.json" + assert config_file.exists() + + weights_file = model_path / "model.safetensors" + if not weights_file.exists(): + weights_file = model_path / "model.safetensors.index.json" + if not weights_file.exists(): + weights_file = model_path / "pytorch_model.bin" + assert weights_file.exists() + + with open(config_file) as f: + config_data = json.load(f) + + assert config_data["model_type"] == "qwen3_5" + assert "text_config" in config_data + assert "vision_config" in config_data + text_cfg = config_data["text_config"] + assert text_cfg["hidden_size"] == 256 + assert text_cfg["num_hidden_layers"] == 4 + + _ = Qwen3_5ForConditionalGeneration.from_pretrained( + qwen35_vl_toy_model_path, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=False, + ) + + @pytest.mark.run_only_on("GPU") + @pytest.mark.parametrize("tp,pp", [(2, 1)]) + def test_qwen35_vl_conversion(self, qwen35_vl_toy_model_path, tmp_path, tp, pp): + """Test dense Qwen3.5 VL conversion with TP parallelism.""" + test_output_dir = tmp_path / "qwen35_vl_test" + test_output_dir.mkdir(exist_ok=True) + + cmd = [ + "python", + "-m", + "torch.distributed.run", + "--nproc_per_node=2", + "--nnodes=1", + "-m", + "coverage", + "run", + "--data-file=/opt/Megatron-Bridge/.coverage", + "--source=/opt/Megatron-Bridge/", + "--parallel-mode", + "examples/conversion/hf_megatron_roundtrip_multi_gpu.py", + "--hf-model-id", + qwen35_vl_toy_model_path, + "--output-dir", + str(test_output_dir), + "--tp", + str(tp), + "--pp", + str(pp), + ] + + result = subprocess.run( + cmd, capture_output=True, text=True, cwd=Path(__file__).parent.parent.parent.parent.parent + ) + + if result.returncode != 0: + print(f"STDOUT: {result.stdout}") + print(f"STDERR: {result.stderr}") + assert False, f"Qwen3.5 VL dense conversion 
failed with return code {result.returncode}" + + model_name = Path(qwen35_vl_toy_model_path).name + converted_model_dir = test_output_dir / model_name + assert converted_model_dir.exists() + + config_file = converted_model_dir / "config.json" + assert config_file.exists() + + with open(config_file) as f: + saved_config = json.load(f) + + assert saved_config["model_type"] == "qwen3_5" + assert "text_config" in saved_config + assert "vision_config" in saved_config + + +# --------------------------------------------------------------------------- +# Tiny MoE config (Qwen3.5 MoE style) +# Mirrors the structure of the real Qwen3.5-35B-A3B config: +# https://huggingface.co/Qwen/Qwen3.5-35B-A3B/blob/main/config.json +# --------------------------------------------------------------------------- +HF_QWEN35_VL_MOE_TOY_MODEL_CONFIG = { + "architectures": ["Qwen3_5MoeForConditionalGeneration"], + "image_token_id": 248056, + "model_type": "qwen3_5_moe", + "torch_dtype": "bfloat16", + "text_config": { + "attention_bias": False, + "attention_dropout": 0.0, + "eos_token_id": 248046, + "full_attention_interval": 4, + "head_dim": 64, + "hidden_act": "silu", + "hidden_size": 256, + "initializer_range": 0.02, + "intermediate_size": 512, + "linear_conv_kernel_dim": 4, + "linear_key_head_dim": 32, + "linear_num_key_heads": 4, + "linear_num_value_heads": 4, + "linear_value_head_dim": 32, + "max_position_embeddings": 32768, + "model_type": "qwen3_5_moe_text", + "moe_intermediate_size": 256, + "num_attention_heads": 4, + "num_experts": 4, + "num_experts_per_tok": 2, + "num_hidden_layers": 4, + "num_key_value_heads": 2, + "rms_norm_eps": 1e-06, + "shared_expert_intermediate_size": 512, + "torch_dtype": "bfloat16", + "use_cache": True, + "vocab_size": 2048, + "rope_parameters": { + "rope_type": "default", + "partial_rotary_factor": 0.25, + "rope_theta": 10000000.0, + "mrope_section": [8, 8, 8], + }, + }, + "tie_word_embeddings": False, + "video_token_id": 248057, + "vision_config": { + 
"deepstack_visual_indexes": [], + "depth": 1, + "hidden_act": "gelu_pytorch_tanh", + "hidden_size": 256, + "in_channels": 3, + "intermediate_size": 512, + "num_heads": 4, + "num_position_embeddings": 2304, + "out_hidden_size": 256, + "patch_size": 14, + "spatial_merge_size": 2, + "temporal_patch_size": 2, + }, + "vision_end_token_id": 248054, + "vision_start_token_id": 248053, +} + + +def _fuse_moe_expert_weights(model_dir: Path, num_experts: int) -> None: + """Fuse per-expert HF weights into the 3-D format the bridge expects. + + HuggingFace's Qwen3.5-MoE model class stores each expert as a separate + ``nn.Linear`` (e.g. ``experts.0.gate_proj.weight``), but the published + checkpoints ship with fused tensors (``experts.gate_up_proj`` of shape + ``[num_experts, 2*intermediate, hidden]`` and ``experts.down_proj`` of shape + ``[num_experts, hidden, intermediate]``). This helper rewrites the saved + safetensors file so the toy model matches the real-checkpoint layout. + """ + from safetensors.torch import load_file, save_file + + weights_path = model_dir / "model.safetensors" + state_dict = load_file(str(weights_path)) + + expert_re = re.compile( + r"^(model\.language_model\.layers\.\d+\.mlp\.experts)" + r"\.(\d+)\.(gate_proj|up_proj|down_proj)\.weight$" + ) + + # Collect per-expert tensors grouped by layer prefix + layers: dict = {} + keys_to_remove: list[str] = [] + for key in state_dict: + m = expert_re.match(key) + if m: + prefix, idx, proj = m.group(1), int(m.group(2)), m.group(3) + layers.setdefault(prefix, {}).setdefault(idx, {})[proj] = state_dict[key] + keys_to_remove.append(key) + + if not keys_to_remove: + return + + new_state_dict = {k: v for k, v in state_dict.items() if k not in keys_to_remove} + + for prefix, experts in layers.items(): + gate_up = torch.stack( + [torch.cat([experts[i]["gate_proj"], experts[i]["up_proj"]], dim=0) for i in range(num_experts)], + dim=0, + ) + down = torch.stack([experts[i]["down_proj"] for i in range(num_experts)], dim=0) 
+ new_state_dict[f"{prefix}.gate_up_proj"] = gate_up + new_state_dict[f"{prefix}.down_proj"] = down + + save_file(new_state_dict, str(weights_path)) + + +@pytest.mark.skipif(not _HAS_QWEN3_5_MOE, reason="transformers does not have Qwen3.5 MoE support") +class TestQwen35VLMoEConversion: + """Test Qwen3.5 VL MoE model conversion.""" + + @pytest.fixture(scope="class") + def qwen35_vl_moe_toy_model_path(self, tmp_path_factory): + """Create and save a MoE Qwen3.5 VL toy model.""" + temp_dir = tmp_path_factory.mktemp("qwen35_vl_moe_toy_model") + model_dir = temp_dir / "qwen35_vl_moe_toy" + + config = Qwen3_5MoeConfig(**HF_QWEN35_VL_MOE_TOY_MODEL_CONFIG) + config.torch_dtype = torch.bfloat16 + + model = Qwen3_5MoeForConditionalGeneration(config) + model = model.to(dtype=torch.bfloat16) + + try: + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3.5-35B-A3B") + tokenizer.save_pretrained(model_dir) + except Exception: + tokenizer_config = { + "tokenizer_class": "Qwen2Tokenizer", + "vocab_size": 248320, + } + model_dir.mkdir(parents=True, exist_ok=True) + with open(model_dir / "tokenizer_config.json", "w") as f: + json.dump(tokenizer_config, f, indent=2) + + model.save_pretrained(model_dir, safe_serialization=True) + + _fuse_moe_expert_weights(model_dir, num_experts=config.text_config.num_experts) + + return str(model_dir) + + def test_moe_toy_model_creation(self, qwen35_vl_moe_toy_model_path): + """Verify the MoE toy model was created correctly.""" + model_path = Path(qwen35_vl_moe_toy_model_path) + assert model_path.exists() + + config_file = model_path / "config.json" + assert config_file.exists() + + with open(config_file) as f: + config_data = json.load(f) + + assert config_data["model_type"] == "qwen3_5_moe" + assert "text_config" in config_data + text_cfg = config_data["text_config"] + assert text_cfg["num_experts"] == 4 + assert text_cfg["full_attention_interval"] == 4 + + _ = 
Qwen3_5MoeForConditionalGeneration.from_pretrained( + qwen35_vl_moe_toy_model_path, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=False, + ) + + @pytest.mark.run_only_on("GPU") + @pytest.mark.parametrize("tp,pp", [(2, 1)]) + def test_moe_conversion(self, qwen35_vl_moe_toy_model_path, tmp_path, tp, pp): + """Test MoE Qwen3.5 VL conversion with TP parallelism.""" + test_output_dir = tmp_path / "qwen35_vl_moe_test" + test_output_dir.mkdir(exist_ok=True) + + cmd = [ + "python", + "-m", + "torch.distributed.run", + "--nproc_per_node=2", + "--nnodes=1", + "-m", + "coverage", + "run", + "--data-file=/opt/Megatron-Bridge/.coverage", + "--source=/opt/Megatron-Bridge/", + "--parallel-mode", + "examples/conversion/hf_megatron_roundtrip_multi_gpu.py", + "--hf-model-id", + qwen35_vl_moe_toy_model_path, + "--output-dir", + str(test_output_dir), + "--tp", + str(tp), + "--pp", + str(pp), + ] + + result = subprocess.run( + cmd, capture_output=True, text=True, cwd=Path(__file__).parent.parent.parent.parent.parent + ) + + if result.returncode != 0: + print(f"STDOUT: {result.stdout}") + print(f"STDERR: {result.stderr}") + assert False, f"Qwen3.5 VL MoE conversion failed with return code {result.returncode}" + + model_name = Path(qwen35_vl_moe_toy_model_path).name + converted_model_dir = test_output_dir / model_name + assert converted_model_dir.exists() + + config_file = converted_model_dir / "config.json" + assert config_file.exists() + + with open(config_file) as f: + saved_config = json.load(f) + + assert saved_config["model_type"] == "qwen3_5_moe" + assert "text_config" in saved_config + assert saved_config["text_config"]["num_experts"] == 4 diff --git a/tests/unit_tests/models/qwen_vl/test_qwen35_vl_bridge.py b/tests/unit_tests/models/qwen_vl/test_qwen35_vl_bridge.py new file mode 100644 index 0000000000..d9d8a52521 --- /dev/null +++ b/tests/unit_tests/models/qwen_vl/test_qwen35_vl_bridge.py @@ -0,0 +1,385 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import Mock, patch + +import pytest +import torch + +from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry +from megatron.bridge.models.hf_pretrained.vlm import PreTrainedVLM +from megatron.bridge.models.qwen_vl.qwen35_vl_bridge import Qwen35VLBridge, Qwen35VLMoEBridge +from megatron.bridge.models.qwen_vl.qwen35_vl_provider import ( + _TRANSFORMERS_HAS_QWEN3_5, + _TRANSFORMERS_HAS_QWEN3_5_MOE, + Qwen35VLModelProvider, + Qwen35VLMoEModelProvider, +) + + +pytestmark = pytest.mark.skipif(not _TRANSFORMERS_HAS_QWEN3_5, reason="transformers does not have qwen3_5 support") + + +def _make_dense_text_config(): + """Create a mock text config matching Qwen3.5-27B dense architecture.""" + cfg = Mock() + cfg.num_hidden_layers = 64 + cfg.hidden_size = 5120 + cfg.intermediate_size = 17408 + cfg.num_attention_heads = 24 + cfg.num_key_value_heads = 4 + cfg.initializer_range = 0.02 + cfg.rms_norm_eps = 1e-6 + cfg.vocab_size = 248320 + cfg.max_position_embeddings = 262144 + cfg.rope_theta = 10000000.0 + cfg.tie_word_embeddings = False + cfg.hidden_act = "silu" + cfg.attention_bias = False + cfg.head_dim = 256 + cfg.full_attention_interval = 4 + cfg.rope_parameters = {"partial_rotary_factor": 0.25, "rope_theta": 10000000.0} + cfg.rope_scaling = {"mrope_section": [11, 11, 10]} + cfg.linear_conv_kernel_dim = 4 + cfg.linear_key_head_dim = 128 + cfg.linear_value_head_dim = 128 + 
cfg.linear_num_key_heads = 16 + cfg.linear_num_value_heads = 48 + cfg.bos_token_id = 248045 + cfg.eos_token_id = 248044 + cfg.q_lora_rank = None + cfg.kv_lora_rank = None + cfg.qk_nope_head_dim = None + cfg.qk_rope_head_dim = None + cfg.v_head_dim = None + cfg.num_nextn_predict_layers = None + return cfg + + +def _make_moe_text_config(): + """Create a mock text config matching Qwen3.5-397B-A17B MoE architecture.""" + cfg = Mock() + cfg.num_hidden_layers = 60 + cfg.hidden_size = 4096 + cfg.intermediate_size = 1024 + cfg.num_attention_heads = 32 + cfg.num_key_value_heads = 2 + cfg.initializer_range = 0.02 + cfg.rms_norm_eps = 1e-6 + cfg.vocab_size = 248320 + cfg.max_position_embeddings = 262144 + cfg.rope_theta = 10000000.0 + cfg.tie_word_embeddings = False + cfg.hidden_act = "silu" + cfg.attention_bias = False + cfg.head_dim = 256 + cfg.full_attention_interval = 4 + cfg.rope_parameters = {"partial_rotary_factor": 0.25, "rope_theta": 10000000.0} + cfg.rope_scaling = {"mrope_section": [11, 11, 10]} + cfg.linear_conv_kernel_dim = 4 + cfg.linear_key_head_dim = 128 + cfg.linear_value_head_dim = 128 + cfg.linear_num_key_heads = 16 + cfg.linear_num_value_heads = 64 + cfg.moe_intermediate_size = 1024 + cfg.num_experts = 512 + cfg.num_experts_per_tok = 10 + cfg.shared_expert_intermediate_size = 4096 + cfg.bos_token_id = 248045 + cfg.eos_token_id = 248046 + cfg.q_lora_rank = None + cfg.kv_lora_rank = None + cfg.qk_nope_head_dim = None + cfg.qk_rope_head_dim = None + cfg.v_head_dim = None + cfg.num_nextn_predict_layers = None + return cfg + + +def _make_vision_config(): + """Create a minimal mock vision config.""" + from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5VisionConfig + + return Qwen3_5VisionConfig() + + +def _make_mock_pretrained(text_config, vision_config, tie_word_embeddings=False): + pretrained = Mock(spec=PreTrainedVLM) + config = Mock() + config.text_config = text_config + config.vision_config = vision_config + config.tie_word_embeddings = 
tie_word_embeddings + config.vision_start_token_id = 248053 + config.vision_end_token_id = 248054 + config.image_token_id = 248056 + config.video_token_id = 248057 + config.audio_token_id = 248076 + pretrained.config = config + return pretrained + + +# ===================================================================== +# Tests for Qwen35VLBridge (Dense) +# ===================================================================== + + +class TestQwen35VLBridgeInitialization: + def test_bridge_initialization(self): + bridge = Qwen35VLBridge() + assert isinstance(bridge, Qwen35VLBridge) + + def test_bridge_has_required_methods(self): + bridge = Qwen35VLBridge() + assert hasattr(bridge, "provider_bridge") and callable(bridge.provider_bridge) + assert hasattr(bridge, "mapping_registry") and callable(bridge.mapping_registry) + + +class TestQwen35VLBridgeProviderBridge: + @pytest.fixture + def bridge(self): + return Qwen35VLBridge() + + @pytest.fixture + def mock_pretrained(self): + return _make_mock_pretrained(_make_dense_text_config(), _make_vision_config()) + + def test_provider_bridge_returns_correct_type(self, bridge, mock_pretrained): + provider = bridge.provider_bridge(mock_pretrained) + assert isinstance(provider, Qwen35VLModelProvider) + + def test_provider_bridge_basic_config(self, bridge, mock_pretrained): + provider = bridge.provider_bridge(mock_pretrained) + assert provider.num_layers == 64 + assert provider.hidden_size == 5120 + assert provider.ffn_hidden_size == 17408 + assert provider.num_attention_heads == 24 + assert provider.num_query_groups == 4 + assert provider.vocab_size == 248320 + + def test_provider_bridge_hybrid_architecture(self, bridge, mock_pretrained): + provider = bridge.provider_bridge(mock_pretrained) + assert provider.layernorm_zero_centered_gamma is True + assert provider.attention_output_gate is True + assert provider.experimental_attention_variant == "gated_delta_net" + assert provider.linear_attention_freq == 4 + + def 
test_provider_bridge_gdn_params(self, bridge, mock_pretrained): + provider = bridge.provider_bridge(mock_pretrained) + assert provider.linear_conv_kernel_dim == 4 + assert provider.linear_key_head_dim == 128 + assert provider.linear_value_head_dim == 128 + assert provider.linear_num_key_heads == 16 + assert provider.linear_num_value_heads == 48 + + def test_provider_bridge_vl_config(self, bridge, mock_pretrained): + provider = bridge.provider_bridge(mock_pretrained) + assert provider.position_embedding_type == "mrope" + assert provider.mrope_section == [11, 11, 10] + assert provider.head_dim == 256 + assert provider.rotary_percent == 0.25 + + def test_provider_bridge_token_ids(self, bridge, mock_pretrained): + provider = bridge.provider_bridge(mock_pretrained) + assert provider.bos_token_id == 248045 + assert provider.eos_token_id == 248044 + assert provider.vision_start_token_id == 248053 + assert provider.vision_end_token_id == 248054 + assert provider.image_token_id == 248056 + assert provider.video_token_id == 248057 + + def test_provider_bridge_common_settings(self, bridge, mock_pretrained): + provider = bridge.provider_bridge(mock_pretrained) + assert provider.normalization == "RMSNorm" + assert provider.gated_linear_unit is True + assert provider.add_qkv_bias is False + assert provider.add_bias_linear is False + assert provider.qk_layernorm is True + assert provider.hidden_dropout == 0.0 + + @patch.object(Qwen35VLBridge, "dtype_from_hf") + def test_provider_bridge_dtype_handling(self, mock_dtype, bridge, mock_pretrained): + mock_dtype.return_value = torch.bfloat16 + provider = bridge.provider_bridge(mock_pretrained) + assert provider.bf16 is True + assert provider.params_dtype == torch.bfloat16 + + def test_provider_bridge_tied_embeddings(self, bridge): + text_config = _make_dense_text_config() + text_config.tie_word_embeddings = True + pretrained = _make_mock_pretrained(text_config, _make_vision_config()) + provider = bridge.provider_bridge(pretrained) + 
assert provider.share_embeddings_and_output_weights is True + + +class TestQwen35VLBridgeMappingRegistry: + @pytest.fixture + def bridge(self): + return Qwen35VLBridge() + + def _get_mapping_names(self, registry): + names = [] + for mapping in registry.mappings: + if hasattr(mapping, "megatron_param"): + names.append(str(getattr(mapping, "megatron_param"))) + hf = getattr(mapping, "hf_param", None) + if isinstance(hf, dict): + names.extend([str(v) for v in hf.values()]) + elif isinstance(hf, str): + names.append(hf) + return names + + def test_mapping_registry_type(self, bridge): + registry = bridge.mapping_registry() + assert isinstance(registry, MegatronMappingRegistry) + assert len(registry.mappings) > 0 + + def test_mapping_registry_has_embeddings(self, bridge): + names = self._get_mapping_names(bridge.mapping_registry()) + assert any("embed_tokens" in n or "word_embeddings" in n for n in names) + + def test_mapping_registry_has_output_layer(self, bridge): + names = self._get_mapping_names(bridge.mapping_registry()) + assert any("lm_head" in n or "output_layer" in n for n in names) + + def test_mapping_registry_has_gdn_mappings(self, bridge): + names = self._get_mapping_names(bridge.mapping_registry()) + assert any("in_proj" in n for n in names), "Should contain GDN in_proj mappings" + assert any("out_proj" in n for n in names), "Should contain GDN out_proj mappings" + assert any("A_log" in n for n in names), "Should contain GDN A_log mappings" + assert any("conv1d" in n for n in names), "Should contain GDN conv1d mappings" + assert any("out_norm" in n or "linear_attn.norm" in n for n in names) + + def test_mapping_registry_has_dense_mlp(self, bridge): + names = self._get_mapping_names(bridge.mapping_registry()) + assert any("gate_proj" in n for n in names), "Should contain gate_proj for dense MLP" + assert any("up_proj" in n for n in names), "Should contain up_proj for dense MLP" + assert any("down_proj" in n for n in names), "Should contain down_proj for 
dense MLP" + + def test_mapping_registry_has_no_moe(self, bridge): + names = self._get_mapping_names(bridge.mapping_registry()) + assert not any("router" in n or "experts" in n for n in names), "Dense model should not have MoE mappings" + + def test_mapping_registry_has_vision_params(self, bridge): + names = self._get_mapping_names(bridge.mapping_registry()) + assert any("visual" in n or "vision_model" in n for n in names) + + def test_mapping_registry_has_qkv(self, bridge): + names = self._get_mapping_names(bridge.mapping_registry()) + assert any("linear_qkv" in n for n in names) + + def test_mapping_registry_has_vision_patch_embed(self, bridge): + names = self._get_mapping_names(bridge.mapping_registry()) + assert any("patch_embed" in n for n in names) + + +# ===================================================================== +# Tests for Qwen35VLMoEBridge +# ===================================================================== + + +@pytest.mark.skipif(not _TRANSFORMERS_HAS_QWEN3_5_MOE, reason="transformers does not have qwen3_5_moe support") +class TestQwen35VLMoEBridgeInitialization: + def test_bridge_initialization(self): + bridge = Qwen35VLMoEBridge() + assert isinstance(bridge, Qwen35VLMoEBridge) + + +@pytest.mark.skipif(not _TRANSFORMERS_HAS_QWEN3_5_MOE, reason="transformers does not have qwen3_5_moe support") +class TestQwen35VLMoEBridgeProviderBridge: + @pytest.fixture + def bridge(self): + return Qwen35VLMoEBridge() + + @pytest.fixture + def mock_pretrained(self): + from transformers.models.qwen3_5_moe.configuration_qwen3_5_moe import Qwen3_5MoeVisionConfig + + return _make_mock_pretrained(_make_moe_text_config(), Qwen3_5MoeVisionConfig()) + + def test_provider_bridge_returns_correct_type(self, bridge, mock_pretrained): + provider = bridge.provider_bridge(mock_pretrained) + assert isinstance(provider, Qwen35VLMoEModelProvider) + + def test_provider_bridge_basic_config(self, bridge, mock_pretrained): + provider = bridge.provider_bridge(mock_pretrained) 
+ assert provider.num_layers == 60 + assert provider.hidden_size == 4096 + assert provider.num_attention_heads == 32 + assert provider.num_query_groups == 2 + assert provider.vocab_size == 248320 + + def test_provider_bridge_moe_config(self, bridge, mock_pretrained): + provider = bridge.provider_bridge(mock_pretrained) + assert provider.num_moe_experts == 512 + assert provider.moe_router_topk == 10 + assert provider.moe_ffn_hidden_size == 1024 + assert provider.moe_shared_expert_gate is True + assert provider.moe_grouped_gemm is True + + def test_provider_bridge_hybrid_architecture(self, bridge, mock_pretrained): + provider = bridge.provider_bridge(mock_pretrained) + assert provider.experimental_attention_variant == "gated_delta_net" + assert provider.linear_attention_freq == 4 + assert provider.layernorm_zero_centered_gamma is True + + def test_provider_bridge_gdn_params(self, bridge, mock_pretrained): + provider = bridge.provider_bridge(mock_pretrained) + assert provider.linear_num_value_heads == 64 + assert provider.linear_key_head_dim == 128 + assert provider.linear_value_head_dim == 128 + + def test_provider_bridge_token_ids(self, bridge, mock_pretrained): + provider = bridge.provider_bridge(mock_pretrained) + assert provider.bos_token_id == 248045 + assert provider.eos_token_id == 248046 + assert provider.image_token_id == 248056 + + +@pytest.mark.skipif(not _TRANSFORMERS_HAS_QWEN3_5_MOE, reason="transformers does not have qwen3_5_moe support") +class TestQwen35VLMoEBridgeMappingRegistry: + @pytest.fixture + def bridge(self): + return Qwen35VLMoEBridge() + + def _get_mapping_names(self, registry): + names = [] + for mapping in registry.mappings: + if hasattr(mapping, "megatron_param"): + names.append(str(getattr(mapping, "megatron_param"))) + hf = getattr(mapping, "hf_param", None) + if isinstance(hf, dict): + names.extend([str(v) for v in hf.values()]) + elif isinstance(hf, str): + names.append(hf) + return names + + def test_mapping_registry_type(self, 
bridge): + registry = bridge.mapping_registry() + assert isinstance(registry, MegatronMappingRegistry) + + def test_mapping_registry_has_moe_mappings(self, bridge): + names = self._get_mapping_names(bridge.mapping_registry()) + assert any("router" in n or "gate.weight" in n for n in names), "Should contain MoE router" + assert any("experts" in n for n in names), "Should contain expert MLPs" + assert any("shared_expert" in n for n in names), "Should contain shared experts" + + def test_mapping_registry_has_gdn_mappings(self, bridge): + names = self._get_mapping_names(bridge.mapping_registry()) + assert any("in_proj" in n for n in names) + assert any("A_log" in n for n in names) + assert any("conv1d" in n for n in names) + + def test_mapping_registry_has_vision_params(self, bridge): + names = self._get_mapping_names(bridge.mapping_registry()) + assert any("visual" in n or "vision_model" in n for n in names) diff --git a/tests/unit_tests/models/qwen_vl/test_qwen35_vl_provider.py b/tests/unit_tests/models/qwen_vl/test_qwen35_vl_provider.py new file mode 100644 index 0000000000..fb399d9044 --- /dev/null +++ b/tests/unit_tests/models/qwen_vl/test_qwen35_vl_provider.py @@ -0,0 +1,242 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +from megatron.bridge.models.gpt_provider import GPTModelProvider +from megatron.bridge.models.qwen_vl.qwen35_vl_provider import ( + _TRANSFORMERS_HAS_QWEN3_5, + _TRANSFORMERS_HAS_QWEN3_5_MOE, + Qwen35VLModelProvider, + Qwen35VLMoEModelProvider, +) + + +pytestmark = pytest.mark.skipif(not _TRANSFORMERS_HAS_QWEN3_5, reason="transformers does not have qwen3_5 support") + + +class TestQwen35VLModelProvider: + """Tests for the dense Qwen3.5 VL model provider.""" + + def test_initialization_defaults(self): + provider = Qwen35VLModelProvider( + num_layers=64, + hidden_size=5120, + num_attention_heads=24, + ) + assert provider.num_layers == 64 + assert provider.hidden_size == 5120 + assert provider.num_attention_heads == 24 + + def test_hybrid_architecture_defaults(self): + provider = Qwen35VLModelProvider( + num_layers=64, + hidden_size=5120, + num_attention_heads=24, + ) + assert provider.layernorm_zero_centered_gamma is True + assert provider.attention_output_gate is True + assert provider.experimental_attention_variant == "gated_delta_net" + assert provider.linear_attention_freq == 4 + + def test_gdn_defaults(self): + provider = Qwen35VLModelProvider( + num_layers=64, + hidden_size=5120, + num_attention_heads=24, + ) + assert provider.linear_conv_kernel_dim == 4 + assert provider.linear_key_head_dim == 128 + assert provider.linear_value_head_dim == 128 + assert provider.linear_num_key_heads == 16 + assert provider.linear_num_value_heads == 48 + + def test_vl_defaults(self): + provider = Qwen35VLModelProvider( + num_layers=64, + hidden_size=5120, + num_attention_heads=24, + ) + assert provider.position_embedding_type == "mrope" + assert provider.mrope_section == [11, 11, 10] + assert provider.image_token_id == 248056 + assert provider.video_token_id == 248057 + assert provider.vision_start_token_id == 248053 + assert provider.vision_end_token_id == 248054 + assert provider.bos_token_id == 248045 + assert provider.eos_token_id == 248044 + + def 
test_common_llm_defaults(self): + provider = Qwen35VLModelProvider( + num_layers=64, + hidden_size=5120, + num_attention_heads=24, + ) + assert provider.normalization == "RMSNorm" + assert provider.gated_linear_unit is True + assert provider.add_bias_linear is False + assert provider.add_qkv_bias is False + assert provider.qk_layernorm is True + assert provider.kv_channels == 256 + assert provider.num_query_groups == 4 + assert provider.hidden_dropout == 0.0 + assert provider.rotary_base == 10000000.0 + assert provider.rotary_percent == 0.25 + + def test_freeze_options_defaults(self): + provider = Qwen35VLModelProvider( + num_layers=64, + hidden_size=5120, + num_attention_heads=24, + ) + assert provider.freeze_language_model is False + assert provider.freeze_vision_model is False + assert provider.freeze_vision_projection is False + + def test_freeze_options_custom(self): + provider = Qwen35VLModelProvider( + num_layers=64, + hidden_size=5120, + num_attention_heads=24, + freeze_language_model=True, + freeze_vision_model=True, + ) + assert provider.freeze_language_model is True + assert provider.freeze_vision_model is True + + def test_custom_mrope_section(self): + provider = Qwen35VLModelProvider( + num_layers=64, + hidden_size=5120, + num_attention_heads=24, + mrope_section=[8, 12, 12], + ) + assert provider.mrope_section == [8, 12, 12] + + def test_vision_config_default_type(self): + from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5VisionConfig + + provider = Qwen35VLModelProvider( + num_layers=64, + hidden_size=5120, + num_attention_heads=24, + ) + assert isinstance(provider.vision_config, Qwen3_5VisionConfig) + + def test_inherits_from_gpt_provider(self): + assert issubclass(Qwen35VLModelProvider, GPTModelProvider) + + def test_provide_methods_exist(self): + provider = Qwen35VLModelProvider( + num_layers=64, + hidden_size=5120, + num_attention_heads=24, + ) + assert hasattr(provider, "provide") and callable(provider.provide) + assert 
hasattr(provider, "provide_language_model") and callable(provider.provide_language_model) + + def test_tp_validation(self): + with pytest.raises(ValueError, match="TP size"): + Qwen35VLModelProvider( + num_layers=64, + hidden_size=5120, + num_attention_heads=24, + num_query_groups=2, + tensor_model_parallel_size=4, + ) + + +@pytest.mark.skipif(not _TRANSFORMERS_HAS_QWEN3_5_MOE, reason="transformers does not have qwen3_5_moe support") +class TestQwen35VLMoEModelProvider: + """Tests for the MoE Qwen3.5 VL model provider.""" + + def test_initialization_defaults(self): + provider = Qwen35VLMoEModelProvider( + num_layers=60, + hidden_size=4096, + num_attention_heads=32, + ) + assert provider.num_layers == 60 + assert provider.hidden_size == 4096 + assert provider.num_attention_heads == 32 + + def test_moe_defaults(self): + provider = Qwen35VLMoEModelProvider( + num_layers=60, + hidden_size=4096, + num_attention_heads=32, + ) + assert provider.num_moe_experts == 512 + assert provider.moe_router_topk == 10 + assert provider.moe_shared_expert_gate is True + assert provider.moe_grouped_gemm is True + assert provider.moe_router_load_balancing_type == "global_aux_loss" + assert provider.moe_router_pre_softmax is False + assert provider.moe_token_dispatcher_type == "alltoall" + + def test_hybrid_architecture_defaults(self): + provider = Qwen35VLMoEModelProvider( + num_layers=60, + hidden_size=4096, + num_attention_heads=32, + ) + assert provider.experimental_attention_variant == "gated_delta_net" + assert provider.linear_attention_freq == 4 + assert provider.layernorm_zero_centered_gamma is True + assert provider.attention_output_gate is True + + def test_gdn_defaults(self): + provider = Qwen35VLMoEModelProvider( + num_layers=60, + hidden_size=4096, + num_attention_heads=32, + ) + assert provider.linear_num_value_heads == 64 + assert provider.linear_num_key_heads == 16 + assert provider.linear_key_head_dim == 128 + assert provider.linear_value_head_dim == 128 + + def 
test_vl_defaults(self): + provider = Qwen35VLMoEModelProvider( + num_layers=60, + hidden_size=4096, + num_attention_heads=32, + ) + assert provider.position_embedding_type == "mrope" + assert provider.mrope_section == [11, 11, 10] + assert provider.bos_token_id == 248045 + assert provider.eos_token_id == 248046 + + def test_inherits_from_gpt_provider(self): + assert issubclass(Qwen35VLMoEModelProvider, GPTModelProvider) + + def test_vision_config_default_type(self): + from transformers.models.qwen3_5_moe.configuration_qwen3_5_moe import Qwen3_5MoeVisionConfig + + provider = Qwen35VLMoEModelProvider( + num_layers=60, + hidden_size=4096, + num_attention_heads=32, + ) + assert isinstance(provider.vision_config, Qwen3_5MoeVisionConfig) + + def test_tp_validation(self): + with pytest.raises(ValueError, match="TP size"): + Qwen35VLMoEModelProvider( + num_layers=60, + hidden_size=4096, + num_attention_heads=32, + num_query_groups=2, + tensor_model_parallel_size=4, + )