From e9fc2367e254710803f4fb374cde31bdfdff70eb Mon Sep 17 00:00:00 2001 From: shen-shanshan <467638484@qq.com> Date: Fri, 28 Nov 2025 07:37:48 +0000 Subject: [PATCH 1/6] remove qwen2-vl Signed-off-by: shen-shanshan <467638484@qq.com> --- vllm_ascend/models/__init__.py | 6 +- vllm_ascend/patch/worker/patch_qwen2_5_vl.py | 210 ++++++++++++++++++- 2 files changed, 212 insertions(+), 4 deletions(-) diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py index 31eae8d7cbe..508745c06d1 100644 --- a/vllm_ascend/models/__init__.py +++ b/vllm_ascend/models/__init__.py @@ -2,9 +2,9 @@ def register_model(): - ModelRegistry.register_model( - "Qwen2VLForConditionalGeneration", - "vllm_ascend.models.qwen2_vl:AscendQwen2VLForConditionalGeneration") + # ModelRegistry.register_model( + # "Qwen2VLForConditionalGeneration", + # "vllm_ascend.models.qwen2_vl:AscendQwen2VLForConditionalGeneration") ModelRegistry.register_model( "Qwen3VLMoeForConditionalGeneration", diff --git a/vllm_ascend/patch/worker/patch_qwen2_5_vl.py b/vllm_ascend/patch/worker/patch_qwen2_5_vl.py index 27f08751bff..1a521835dae 100644 --- a/vllm_ascend/patch/worker/patch_qwen2_5_vl.py +++ b/vllm_ascend/patch/worker/patch_qwen2_5_vl.py @@ -24,17 +24,27 @@ import torch_npu from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import \ Qwen2_5_VLVisionConfig +from transformers.models.qwen2_vl.configuration_qwen2_vl import \ + Qwen2VLVisionConfig from vllm.attention.backends.registry import AttentionBackendEnum -from vllm.attention.layer import maybe_get_vit_flash_attn_backend +from vllm.attention.layer import (check_upstream_fa_availability, + maybe_get_vit_flash_attn_backend) from vllm.model_executor.layers.activation import get_act_and_mul_fn from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.rotary_embedding.common import ( + apply_rotary_emb_torch, dispatch_rotary_emb_function) from vllm.model_executor.models.qwen2_5_vl import ( Qwen2_5_VisionAttention, Qwen2_5_VisionBlock, Qwen2_5_VisionPatchEmbed, Qwen2_5_VisionPatchMerger, Qwen2_5_VisionTransformer, Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLImageInputs, Qwen2_5_VLVideoInputs) +from vllm.model_executor.models.qwen2_vl import (Qwen2VisionAttention, + Qwen2VisionBlock, + Qwen2VisionPatchEmbed, + Qwen2VisionPatchMerger, + Qwen2VisionTransformer) from vllm.model_executor.models.utils import cast_overflow_tensors from vllm.model_executor.models.vision import ( get_vit_attn_backend, run_dp_sharded_mrope_vision_model) @@ -132,6 +142,190 @@ def forward( return output +class AscendQwen2VisionBlock(nn.Module): + + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb_cos: torch.Tensor, + rotary_pos_emb_sin: torch.Tensor, + max_seqlen: int | None = None, # Only used for Flash Attention + seqlens: list[int] | None = None, # Only used for xFormers + ) -> torch.Tensor: + x = x + self.attn( + self.norm1(x), + cu_seqlens=cu_seqlens, + rotary_pos_emb_cos=rotary_pos_emb_cos, + rotary_pos_emb_sin=rotary_pos_emb_sin, + max_seqlen=max_seqlen, + seqlens=seqlens, + ) + x = x + self.mlp(self.norm2(x)) + return x + + +class AscendQwen2VisionTransformer(nn.Module): + + def __init__( + self, + vision_config: Qwen2VLVisionConfig, + norm_eps: float = 1e-6, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + use_data_parallel: bool = False, + attn_backend_override: AttentionBackendEnum | None = None, + ) -> None: + nn.Module.__init__(self) + + patch_size = vision_config.patch_size + temporal_patch_size = vision_config.temporal_patch_size + spatial_merge_size = vision_config.spatial_merge_size + in_channels = vision_config.in_channels + hidden_size = vision_config.hidden_size + embed_dim = vision_config.embed_dim + depth = vision_config.depth + num_heads = vision_config.num_heads + mlp_ratio = vision_config.mlp_ratio + + self.use_data_parallel = use_data_parallel + self.out_hidden_size = vision_config.hidden_size + + self.spatial_merge_size = spatial_merge_size + self.num_heads = num_heads + self.embed_dim = embed_dim + + self.patch_embed = Qwen2VisionPatchEmbed( + patch_size=patch_size, + temporal_patch_size=temporal_patch_size, + in_channels=in_channels, + embed_dim=embed_dim, + ) + + norm_layer = partial(nn.LayerNorm, eps=norm_eps) + head_dim = embed_dim // num_heads + self.rotary_pos_emb = get_rope( + head_size=head_dim, + rotary_dim=head_dim // 2, + max_position=8192, + base=10000.0, + is_neox_style=True, + ) + + self.blocks = nn.ModuleList([ + Qwen2VisionBlock( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}", + use_data_parallel=use_data_parallel, + attn_backend_override=attn_backend_override, + ) for layer_idx in range(depth) + ]) + self.merger = Qwen2VisionPatchMerger( + d_model=hidden_size, + context_dim=embed_dim, + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.merger", + use_data_parallel=use_data_parallel, + ) + self.attn_backend = get_vit_attn_backend( + head_size=head_dim, + dtype=torch.get_default_dtype(), + attn_backend_override=attn_backend_override, + ) + + if (self.attn_backend != AttentionBackendEnum.FLASH_ATTN + and check_upstream_fa_availability(torch.get_default_dtype())): + self.attn_backend = AttentionBackendEnum.FLASH_ATTN + + def rot_pos_emb( + self, + grid_thw: list[list[int]]) -> tuple[torch.Tensor, torch.Tensor]: + pos_ids = [] + max_grid_size = 0 + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + hpos_ids = (hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ).permute(0, 2, 1, 3).flatten()) + wpos_ids = (wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ).permute(0, 2, 1, 3).flatten()) + pos_ids.append( + torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + max_grid_size = max(max_grid_size, h, w) + pos_ids = torch.cat(pos_ids, dim=0) + + # Use pre-computed cos_sin_cache from RotaryEmbedding + cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size) + + cos_h = cos[pos_ids[:, 0]] # (num_tokens, rotary_dim // 2) + cos_w = cos[pos_ids[:, 1]] + sin_h = sin[pos_ids[:, 0]] + sin_w = sin[pos_ids[:, 1]] + + cos_combined = torch.cat([cos_h, cos_w], dim=-1) + sin_combined = torch.cat([sin_h, sin_w], dim=-1) + return cos_combined, sin_combined + + def forward( + self, + x: torch.Tensor, + grid_thw: torch.Tensor | list[list[int]], + ) -> torch.Tensor: + # patchify + x = x.to(device=self.device, dtype=self.dtype) + x = self.patch_embed(x) + + if isinstance(grid_thw, list): + grid_thw_list = grid_thw + grid_thw = torch.tensor(grid_thw, dtype=torch.int32) + else: + grid_thw_list = grid_thw.tolist() + + # compute position embedding + rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb( + grid_thw_list) + + # compute cu_seqlens + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], + grid_thw[:, 0]).cumsum( + dim=0, dtype=torch.int32) + cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens]) + cu_seqlens = cu_seqlens.to(self.device, non_blocking=True) + + # transformers + x = x.unsqueeze(1) + + # pre-compute seqlens for attn mask to reduce cuMemcpy operations + max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens) + for blk in self.blocks: + x = blk( + x, + cu_seqlens=cu_seqlens, + rotary_pos_emb_cos=rotary_pos_emb_cos, + rotary_pos_emb_sin=rotary_pos_emb_sin, + max_seqlen=max_seqlen, + seqlens=seqlens, + ) + + # adapter + x = self.merger(x) + + return x + + class AscendQwen2_5_VisionBlock(nn.Module): def forward( @@ -486,7 +680,16 @@ def _process_video_input( return video_embeds.split(sizes) +def _apply_rotary_pos_emb_vision(t: torch.Tensor, cos: torch.Tensor, + sin: torch.Tensor) -> torch.Tensor: + rotary_emb_function = dispatch_rotary_emb_function( + default=partial(apply_rotary_emb_torch, is_neox_style=True)) + output = rotary_emb_function(t, cos, sin).type_as(t) + return output + + # NOTE: This will be removed after MMEncoderAttention has been extract as a CustomOp in vllm. +Qwen2VisionAttention.forward = AscendQwen2_5_VisionAttention.forward Qwen2_5_VisionAttention.forward = AscendQwen2_5_VisionAttention.forward # NOTE: These will be removed after https://github.com/vllm-project/vllm/pull/29388 is merged. @@ -494,8 +697,13 @@ def _process_video_input( Qwen2_5_VLForConditionalGeneration._process_video_input = AscendQwen2_5_VLForConditionalGeneration._process_video_input # NOTE: These will be removed after vllm-ascend is aligned with vllm latest main. +Qwen2VisionBlock.forward = AscendQwen2VisionBlock.forward +Qwen2VisionTransformer.__init__ = AscendQwen2VisionTransformer.__init__ +Qwen2VisionTransformer.rot_pos_emb = AscendQwen2VisionTransformer.rot_pos_emb +Qwen2VisionTransformer.forward = AscendQwen2VisionTransformer.forward Qwen2_5_VisionBlock.forward = AscendQwen2_5_VisionBlock.forward Qwen2_5_VisionTransformer.__init__ = AscendQwen2_5_VisionTransformer.__init__ Qwen2_5_VisionTransformer.rotary_pos_emb_thw = AscendQwen2_5_VisionTransformer.rotary_pos_emb_thw Qwen2_5_VisionTransformer.get_rope_by_thw = AscendQwen2_5_VisionTransformer.get_rope_by_thw Qwen2_5_VisionTransformer.forward = AscendQwen2_5_VisionTransformer.forward +apply_rotary_pos_emb_vision = _apply_rotary_pos_emb_vision From 4983dda7e303635945ede0e10c5a9840a90faa5c Mon Sep 17 00:00:00 2001 From: shen-shanshan <467638484@qq.com> Date: Fri, 28 Nov 2025 08:00:59 +0000 Subject: [PATCH 2/6] update Signed-off-by: shen-shanshan <467638484@qq.com> --- vllm_ascend/models/__init__.py | 4 - vllm_ascend/models/qwen2_vl.py | 373 --------------------------------- vllm_ascend/utils.py | 3 +- 3 files changed, 1 insertion(+), 379 deletions(-) delete mode 100644 vllm_ascend/models/qwen2_vl.py diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py index 508745c06d1..5bbc42bf468 100644 --- a/vllm_ascend/models/__init__.py +++ b/vllm_ascend/models/__init__.py @@ -2,10 +2,6 @@ def register_model(): - # ModelRegistry.register_model( - # "Qwen2VLForConditionalGeneration", - # "vllm_ascend.models.qwen2_vl:AscendQwen2VLForConditionalGeneration") - ModelRegistry.register_model( "Qwen3VLMoeForConditionalGeneration", "vllm_ascend.models.qwen3_vl:AscendQwen3VLMoeForConditionalGeneration") diff --git a/vllm_ascend/models/qwen2_vl.py b/vllm_ascend/models/qwen2_vl.py deleted file mode 100644 index f24f9823648..00000000000 --- a/vllm_ascend/models/qwen2_vl.py +++ /dev/null @@ -1,373 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2023 The vLLM team. -# -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Adapted from vllm/model_executor/models/qwen2_vl.py -# This file is a part of the vllm-ascend project. - -from collections.abc import Iterable -from functools import partial -from typing import Callable, Optional, Set, Tuple, Type - -import torch -import torch.nn as nn -import torch_npu -from einops import rearrange -from transformers.models.qwen2_vl.configuration_qwen2_vl import \ - Qwen2VLVisionConfig -from vllm.config import VllmConfig -from vllm.distributed import utils as dist_utils -from vllm.model_executor.layers.activation import QuickGELU -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.qwen2_vl import ( - Qwen2VisionAttention, Qwen2VisionBlock, Qwen2VisionPatchEmbed, - Qwen2VisionTransformer, Qwen2VLDummyInputsBuilder, - Qwen2VLForConditionalGeneration, Qwen2VLMultiModalProcessor, - Qwen2VLProcessingInfo) -from vllm.model_executor.models.utils import maybe_prefix -from vllm.model_executor.models.vision import conv3d_to_linear_weight -from vllm.multimodal import MULTIMODAL_REGISTRY - -from vllm_ascend.utils import ACL_FORMAT_FRACTAL_ND, is_enable_nz - -MIN_PAD_SIZE = 64 # min_size to pad weight -MAX_PAD_SIZE = 128 # max_size to pad weight - - -class AscendQwen2VisionAttention(Qwen2VisionAttention): - - def __init__( - self, - embed_dim: int, - num_heads: int, - projection_size: int, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - super().__init__( - embed_dim, - num_heads, - projection_size, - quant_config, - prefix, - ) - self.cu_seqlens = None - self.hidden_size_per_attention_head = dist_utils.divide( - projection_size, num_heads) - self.origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head - if self.hidden_size_per_attention_head > MIN_PAD_SIZE and self.hidden_size_per_attention_head < MAX_PAD_SIZE: - self.hidden_size_per_attention_head = MAX_PAD_SIZE - - def forward( - self, - x: torch.Tensor, - cu_seqlens: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - ) -> torch.Tensor: - - self.cu_seqlens = cu_seqlens - - # [s, b, c] --> [s, b, 3 * head * head_dim] - x, _ = self.qkv(x) - - # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim] - q, k, v = self.split_qkv(x) - batch_size = q.shape[1] - - q, k, v = [ - rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v) - ] - q = torch_npu.npu_rotary_mul(q, cos, sin) - k = torch_npu.npu_rotary_mul(k, cos, sin) - q, k, v = [ - rearrange(x, "b s h d -> (b s) h d").contiguous() - for x in (q, k, v) - ] - - context_layer = torch.empty_like(q) - - # operator requires pta version >= 2.5.1 - torch_npu._npu_flash_attention_unpad( - query=q, - key=k, - value=v, - seq_len=self.cu_seqlens, - scale_value=self.origin_hidden_size_per_attention_head**-0.5, - num_heads=self.num_attention_heads_per_partition, - num_kv_heads=self.num_attention_heads_per_partition, - out=context_layer) - context_layer = rearrange(context_layer, - "(b s) h d -> s b (h d)", - b=batch_size).contiguous() - - output, _ = self.proj(context_layer) - return output - - -class AscendQwen2VisionBlock(Qwen2VisionBlock): - - def __init__( - self, - dim: int, - num_heads: int, - mlp_ratio: float, - act_layer: Type[nn.Module] = QuickGELU, - norm_layer: Optional[Callable[[int], nn.Module]] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - super().__init__(dim, num_heads, mlp_ratio, act_layer, norm_layer, - quant_config, prefix) - self.attn = AscendQwen2VisionAttention(embed_dim=dim, - num_heads=num_heads, - projection_size=dim, - quant_config=quant_config, - prefix=f"{prefix}.attn") - - def forward( - self, - x: torch.Tensor, - cu_seqlens: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - ) -> torch.Tensor: - x = x + self.attn( - self.norm1(x), - cu_seqlens=cu_seqlens, - cos=cos, - sin=sin, - ) - - x = x + self.mlp(self.norm2(x)) - return x - - -class AscendQwen2VisionPatchEmbed(Qwen2VisionPatchEmbed): - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = x.matmul( - self.proj.weight.data.view(self.embed_dim, -1).transpose(0, 1)) - return x - - -class AscendQwen2VisionTransformer(Qwen2VisionTransformer): - - def __init__( - self, - vision_config: Qwen2VLVisionConfig, - norm_eps: float = 1e-6, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - interleaved=False, - ) -> None: - super().__init__(vision_config, norm_eps, quant_config, prefix) - - self.interleaved = interleaved - self.enable_pad = False - self.depth = vision_config.depth - self.hidden_size = vision_config.embed_dim - self.num_heads = vision_config.num_heads - self.patch_embed = AscendQwen2VisionPatchEmbed( - patch_size=vision_config.patch_size, - temporal_patch_size=vision_config.temporal_patch_size, - in_channels=vision_config.in_channels, - embed_dim=vision_config.embed_dim, - ) - - self.blocks = nn.ModuleList([ - AscendQwen2VisionBlock(dim=self.embed_dim, - num_heads=self.num_heads, - mlp_ratio=vision_config.mlp_ratio, - norm_layer=partial(nn.LayerNorm, - eps=norm_eps), - quant_config=quant_config, - prefix=f"{prefix}.blocks.{layer_idx}") - for layer_idx in range(vision_config.depth) - ]) - - self.hidden_size_per_attention_head = dist_utils.divide( - self.hidden_size, self.num_heads) - - if self.hidden_size_per_attention_head > MIN_PAD_SIZE and self.hidden_size_per_attention_head < MAX_PAD_SIZE: - self.enable_pad = True - self.origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head - self.half_origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head // 2 - self.half_pad_hidden_size_per_attention_head = ( - MAX_PAD_SIZE - self.hidden_size_per_attention_head) // 2 - self.hidden_size_per_attention_head = MAX_PAD_SIZE - - def cal_cos_sin(self, rotary_pos_emb): - cos = rotary_pos_emb.cos() # [seqlen, rotary_dim / 2] - sin = rotary_pos_emb.sin() - if self.enable_pad: - cos = torch.nn.functional.pad( - cos, (0, self.half_pad_hidden_size_per_attention_head)) - sin = torch.nn.functional.pad( - sin, (0, self.half_pad_hidden_size_per_attention_head)) - - if not self.interleaved: - cos_new = torch.cat((cos, cos), dim=-1) - sin_new = torch.cat((sin, sin), dim=-1) - else: - cos_new = rearrange(torch.stack((cos, cos), dim=-1), - "... d two -> ...(d two)", - two=2) - sin_new = rearrange(torch.stack((sin, sin), dim=-1), - "... d two -> ...(d two)", - two=2) - cos_new = cos_new.reshape(1, -1, 1, - self.hidden_size_per_attention_head) - sin_new = sin_new.reshape(1, -1, 1, - self.hidden_size_per_attention_head) - return cos_new, sin_new - - def pad_qkv_bias(self, bias): - first_half = bias.reshape( - -1, 3, self.origin_hidden_size_per_attention_head - )[:, :, :self.half_origin_hidden_size_per_attention_head] - second_half = bias.reshape( - -1, 3, self.origin_hidden_size_per_attention_head - )[:, :, self.half_origin_hidden_size_per_attention_head:] - first_half_padded = torch.nn.functional.pad( - first_half, (0, self.half_pad_hidden_size_per_attention_head)) - second_half_padded = torch.nn.functional.pad( - second_half, (0, self.half_pad_hidden_size_per_attention_head)) - bias_padded = torch.cat([first_half_padded, second_half_padded], dim=2) - bias_final = bias_padded.reshape(-1) - return bias_final - - def pad_qkv_weight(self, data): - qkv_weight_first_half = data.reshape( - -1, 3, self.origin_hidden_size_per_attention_head, self.hidden_size - )[:, :, :self.half_origin_hidden_size_per_attention_head, :] - qkv_weight_second_half = data.reshape( - -1, 3, self.origin_hidden_size_per_attention_head, self.hidden_size - )[:, :, self.half_origin_hidden_size_per_attention_head:, :] - - qkv_weight_first_half_padded = torch.nn.functional.pad( - qkv_weight_first_half, - (0, 0, 0, self.half_pad_hidden_size_per_attention_head)) - qkv_weight_second_half_padded = torch.nn.functional.pad( - qkv_weight_second_half, - (0, 0, 0, self.half_pad_hidden_size_per_attention_head)) - qkv_weight_padded = torch.cat( - [qkv_weight_first_half_padded, qkv_weight_second_half_padded], - dim=2) - qkv_weight_final = qkv_weight_padded.reshape(-1, self.hidden_size) - - if is_enable_nz(): - qkv_weight_final_copy = torch.empty_like(qkv_weight_final).copy_( - qkv_weight_final) - qkv_weight_final_copy = torch_npu.npu_format_cast( - qkv_weight_final_copy, ACL_FORMAT_FRACTAL_ND) - return qkv_weight_final_copy - - return qkv_weight_final - - def pad_proj_weight(self, data): - out_weight = torch.nn.functional.pad( - data.reshape(self.hidden_size, -1, - self.half_origin_hidden_size_per_attention_head), - (0, self.half_pad_hidden_size_per_attention_head, 0, 0)).reshape( - self.hidden_size, -1) - - if is_enable_nz(): - out_weight_copy = torch.empty_like(out_weight).copy_(out_weight) - out_weight_copy = torch_npu.npu_format_cast( - out_weight_copy, ACL_FORMAT_FRACTAL_ND) - return out_weight_copy - - return out_weight - - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ] - params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: Set[str] = set() - - for name, loaded_weight in weights: - if name.endswith("patch_embed.proj.weight"): - loaded_weight = conv3d_to_linear_weight(loaded_weight) - - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - if ("attn.proj.weight" in name) and self.enable_pad: - param.data = self.pad_proj_weight(param.data) - if ("attn.qkv.weight" in name) and self.enable_pad: - param.data = self.pad_qkv_weight(param.data) - if ("attn.qkv.bias" in name) and self.enable_pad: - param.data = self.pad_qkv_bias(param.data) - loaded_params.add(name) - return loaded_params - - def forward( - self, - x: torch.Tensor, - grid_thw: torch.Tensor, - ) -> torch.Tensor: - grid_thw = torch.tensor(grid_thw, dtype=torch.int32) - # compute cu_seqlens and avoid cumsum to fit operator unpadFA - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], - grid_thw[:, - 0]).cpu().to(torch.int32) - - # patchify - x = x.to(device=self.device, dtype=self.dtype) - x = self.patch_embed(x) - - # compute position embedding - rotary_pos_emb = self.rot_pos_emb(grid_thw) - cos, sin = self.cal_cos_sin(rotary_pos_emb) - - x = x.unsqueeze(1) - for blk in self.blocks: - x = blk(x, cu_seqlens=cu_seqlens, cos=cos, sin=sin) - - # adapter - x = self.merger(x) - return x - - -@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor, - info=Qwen2VLProcessingInfo, - dummy_inputs=Qwen2VLDummyInputsBuilder) -class AscendQwen2VLForConditionalGeneration(Qwen2VLForConditionalGeneration): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, prefix=prefix) - self.visual = AscendQwen2VisionTransformer( - self.config.vision_config, - norm_eps=getattr(self.config, "rms_norm_eps", 1e-6), - quant_config=vllm_config.quant_config, - prefix=maybe_prefix(prefix, "visual"), - ) diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 0a74bcbfdcf..438683d553f 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -707,8 +707,7 @@ class AscendDeviceType(Enum): def _init_ascend_device_type(): global _ascend_device_type - from vllm_ascend import _build_info # type: ignore - _ascend_device_type = AscendDeviceType[_build_info.__device_type__] + _ascend_device_type = AscendDeviceType._910B def check_ascend_device_type(): From b773cbfce84a393f3adac10ed8e71129c6f41a51 Mon Sep 17 00:00:00 2001 From: shen-shanshan <467638484@qq.com> Date: Fri, 28 Nov 2025 08:04:11 +0000 Subject: [PATCH 3/6] update Signed-off-by: shen-shanshan <467638484@qq.com> --- vllm_ascend/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 438683d553f..0a74bcbfdcf 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -707,7 +707,8 @@ class AscendDeviceType(Enum): def _init_ascend_device_type(): global _ascend_device_type - _ascend_device_type = AscendDeviceType._910B + from vllm_ascend import _build_info # type: ignore + _ascend_device_type = AscendDeviceType[_build_info.__device_type__] def check_ascend_device_type(): From 942ea93c65b87ec73bdf70d6f7d67657e421bbaf Mon Sep 17 00:00:00 2001 From: shen-shanshan <467638484@qq.com> Date: Fri, 28 Nov 2025 08:49:33 +0000 Subject: [PATCH 4/6] fix lint Signed-off-by: shen-shanshan <467638484@qq.com> --- vllm_ascend/patch/worker/patch_qwen2_5_vl.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm_ascend/patch/worker/patch_qwen2_5_vl.py b/vllm_ascend/patch/worker/patch_qwen2_5_vl.py index 1a521835dae..b5f1a2e32e7 100644 --- a/vllm_ascend/patch/worker/patch_qwen2_5_vl.py +++ b/vllm_ascend/patch/worker/patch_qwen2_5_vl.py @@ -270,10 +270,11 @@ def rot_pos_emb( # Use pre-computed cos_sin_cache from RotaryEmbedding cos, sin = self.rotary_pos_emb.get_cos_sin(max_grid_size) - cos_h = cos[pos_ids[:, 0]] # (num_tokens, rotary_dim // 2) - cos_w = cos[pos_ids[:, 1]] - sin_h = sin[pos_ids[:, 0]] - sin_w = sin[pos_ids[:, 1]] + # (num_tokens, rotary_dim // 2) + cos_h = cos[pos_ids[:, 0]] # type: ignore + cos_w = cos[pos_ids[:, 1]] # type: ignore + sin_h = sin[pos_ids[:, 0]] # type: ignore + sin_w = sin[pos_ids[:, 1]] # type: ignore cos_combined = torch.cat([cos_h, cos_w], dim=-1) sin_combined = torch.cat([sin_h, sin_w], dim=-1) From 519ae3431cd4d32c5cdd371b25953126241c4fca Mon Sep 17 00:00:00 2001 From: shen-shanshan <467638484@qq.com> Date: Fri, 28 Nov 2025 09:08:01 +0000 Subject: [PATCH 5/6] remove ut Signed-off-by: shen-shanshan <467638484@qq.com> --- tests/ut/models/test_qwen2_vl.py | 200 ------------------------------- 1 file changed, 200 deletions(-) delete mode 100644 tests/ut/models/test_qwen2_vl.py diff --git a/tests/ut/models/test_qwen2_vl.py b/tests/ut/models/test_qwen2_vl.py deleted file mode 100644 index d62b8594bae..00000000000 --- a/tests/ut/models/test_qwen2_vl.py +++ /dev/null @@ -1,200 +0,0 @@ -import pytest -import torch -from pytest_mock import MockerFixture -from vllm.model_executor.layers.activation import QuickGELU - -from tests.ut.base import PytestBase -from vllm_ascend.models.qwen2_vl import (AscendQwen2VisionAttention, - AscendQwen2VisionBlock) - - -class TestAscendQwen2VisionAttention(PytestBase): - - def init_attention( - self, - mocker, - embed_dim=1000, - num_heads=10, - projection_size=100, - quant_config=None, - prefix="", - ): - mocker_attn = mocker.patch( - "vllm_ascend.models.qwen2_vl.Qwen2VisionAttention.__init__") - - attention = AscendQwen2VisionAttention( - embed_dim=embed_dim, - num_heads=num_heads, - projection_size=projection_size, - quant_config=quant_config, - prefix=prefix, - ) - args, kwargs = mocker_attn.call_args - assert args == (embed_dim, num_heads, projection_size, None, "") - assert not kwargs - attention.num_attention_heads_per_partition = num_heads - return attention - - def test_attn_init_should_normal(self, mocker: MockerFixture): - embed_dim = 1000 - num_heads = 10 - projection_size = 100 - quant_config = None - prefix = "" - vit = self.init_attention( - embed_dim=embed_dim, - num_heads=num_heads, - projection_size=projection_size, - quant_config=quant_config, - prefix=prefix, - mocker=mocker, - ) - assert vit.hidden_size_per_attention_head == 10 - - def test_attn_init_should_raise_error(self, mocker: MockerFixture): - embed_dim = 1000 - num_heads = 7 - projection_size = 100 - quant_config = None - prefix = "" - with pytest.raises(AssertionError): - # projection_size should divided by num heads - self.init_attention( - mocker=mocker, - embed_dim=embed_dim, - num_heads=num_heads, - projection_size=projection_size, - quant_config=quant_config, - prefix=prefix, - ) - - def test_attn_forward(self, mocker: MockerFixture): - attention = self.init_attention(mocker=mocker) - mocker.patch("torch.nn.Module.__setattr__") - mocker.patch("torch.nn.Module.__getattr__") - mocker.patch("torch.nn.Module.__delattr__") - x = torch.rand((100, 3, 10 * 3 * 128)) # s,b, head*3*head_dim - cu_seqlens = torch.tensor([10, 50, 100]) - cos = torch.rand((1, 100, 1, 128)) - sin = torch.rand((1, 100, 1, 128)) - - qkv = lambda x: (x, 0) # noqa - split_qkv = lambda x: [ #noqa - torch.rand((100, 3, 10, 128)) for i in range(3) - ] # noqa - npu_rotary_mul = lambda q, cos, sin: q # noqa - _npu_flash_attention_unpad = lambda **kwargs: kwargs["out"] # noqa - proj = lambda x: (x, 0) # noqa - - mocker_qkv = mocker.patch.object(attention, "qkv", side_effect=qkv) - mocker_split_qkv = mocker.patch.object( - attention, - "split_qkv", - side_effect=split_qkv, - ) - mocker_npu_rotary_mul = mocker.patch("torch_npu.npu_rotary_mul", - side_effect=npu_rotary_mul) - mocker_npu_flash_attention_unpad = mocker.patch( - "torch_npu._npu_flash_attention_unpad", - side_effect=_npu_flash_attention_unpad, - ) - mocker_proj = mocker.patch.object(attention, "proj", side_effect=proj) - attention.__dict__["qkv"] = mocker_qkv - attention.__dict__["split_qkv"] = mocker_split_qkv - attention.__dict__["npu_rotary_mul"] = mocker_npu_rotary_mul - attention.__dict__["_npu_flash_attention_unpad"] = ( - mocker_npu_flash_attention_unpad) - attention.__dict__["proj"] = mocker_proj - - output = attention.forward( - x=x, - cu_seqlens=cu_seqlens, - cos=cos, - sin=sin, - ) - qkv_args, qkv_kwargs = mocker_qkv.call_args - assert qkv_args == (x, ) - assert not qkv_kwargs - - split_qkv_args, split_qkv_kwargs = mocker_split_qkv.call_args - assert split_qkv_args == (x, ) - assert not split_qkv_kwargs - - npu_rotary_mul_args, npu_rotary_mul_kwargs = mocker_npu_rotary_mul.call_args - assert npu_rotary_mul_args[1:] == (cos, sin) - assert npu_rotary_mul_args[0].shape == torch.Size([3, 100, 10, 128]) - assert not npu_rotary_mul_kwargs - - assert output.shape == torch.Size([100, 3, 1280]) - - -class TestAscendQwen2VisionBlock(PytestBase): - - def init_vision_block( - self, - mocker, - dim=100, - num_heads=10, - mlp_ratio=0.5, - ): - mocker_vit = mocker.patch( - "vllm.model_executor.models.qwen2_vl.Qwen2VisionBlock.__init__", - return_value=None, - ) - - mocker_attn = mocker.patch( - "vllm_ascend.models.qwen2_vl.AscendQwen2VisionAttention.__init__", - return_value=None, - ) - - mocker.patch("torch.nn.Module.__setattr__") - mocker.patch("torch.nn.Module.__getattr__") - mocker.patch("torch.nn.Module.__delattr__") - vision_block = AscendQwen2VisionBlock( - dim=dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - ) - args, kwargs = mocker_vit.call_args - assert args == (dim, num_heads, mlp_ratio, QuickGELU, None, None, "") - assert not kwargs - - args1, kwargs1 = mocker_attn.call_args - assert not args1 - assert kwargs1 == { - "embed_dim": dim, - "num_heads": num_heads, - "projection_size": dim, - "quant_config": None, - "prefix": ".attn", - } - return vision_block - - def test_init_vision_block_should_normal( - self, - mocker: MockerFixture, - ): - vision_block = self.init_vision_block(mocker) - assert isinstance(vision_block, AscendQwen2VisionBlock) - - def test_vision_block_forward(self, mocker: MockerFixture): - x = torch.randint(1, 100, (100, 3, 1280)) # s,b,d - cu_seqlens = torch.tensor([10, 50, 100]) - cos = torch.rand((1, 100, 1, 128)) - sin = torch.rand((1, 100, 1, 128)) - vision_block = self.init_vision_block(mocker) - mocker_attn = mocker.patch.object(vision_block, "attn", return_value=x) - mocker_mlp = mocker.patch.object(vision_block, "mlp", return_value=x) - vision_block.__dict__["attn"] = mocker_attn - vision_block.__dict__["mlp"] = mocker_mlp - - output = vision_block.forward(x.clone(), cu_seqlens, cos, sin) - - _, attn_kwargs = mocker_attn.call_args - assert attn_kwargs == { - "cu_seqlens": cu_seqlens, - "cos": cos, - "sin": sin, - } - - assert torch.all(x * 3 == output) From a50aaa78dd989ba37cb4a16a3a0be35b1b2c20ee Mon Sep 17 00:00:00 2001 From: shen-shanshan <467638484@qq.com> Date: Sat, 29 Nov 2025 06:56:27 +0000 Subject: [PATCH 6/6] fix rope Signed-off-by: shen-shanshan <467638484@qq.com> --- vllm_ascend/patch/worker/patch_qwen2_5_vl.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/vllm_ascend/patch/worker/patch_qwen2_5_vl.py b/vllm_ascend/patch/worker/patch_qwen2_5_vl.py index b5f1a2e32e7..464c62830b6 100644 --- a/vllm_ascend/patch/worker/patch_qwen2_5_vl.py +++ b/vllm_ascend/patch/worker/patch_qwen2_5_vl.py @@ -84,18 +84,8 @@ def forward( # Convert cumulative tensor to intervals and move it to cpu. cu_seqlens = torch.diff(cu_seqlens).to("cpu") - cos = rotary_pos_emb_cos - sin = rotary_pos_emb_sin - cos = einops.rearrange( - torch.stack((cos, cos), dim=-1), - "... d two -> ...(d two)", - two=2, - ) - sin = einops.rearrange( - torch.stack((sin, sin), dim=-1), - "... d two -> ...(d two)", - two=2, - ) + cos = torch.cat((rotary_pos_emb_cos, rotary_pos_emb_cos), dim=-1) + sin = torch.cat((rotary_pos_emb_sin, rotary_pos_emb_sin), dim=-1) cos = cos.reshape(1, -1, 1, self.hidden_size_per_attention_head) sin = sin.reshape(1, -1, 1, self.hidden_size_per_attention_head) q = torch_npu.npu_rotary_mul(q, cos, sin)