From 31361e54d68df04495a247fd954cd1c1e1d608e6 Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Thu, 4 Dec 2025 03:57:11 +0000 Subject: [PATCH 01/40] [Graph] [Fusion] Fusion slice and qknorm operator Signed-off-by: wxsIcey <1790571317@qq.com> --- vllm_ascend/ascend_config.py | 6 +- .../compilation/graph_fusion_pass_manager.py | 6 +- vllm_ascend/compilation/matcher_utils.py | 61 ++++ .../compilation/passes/qknorm_fusion_pass.py | 210 ++++++++++++ vllm_ascend/ops/__init__.py | 1 + vllm_ascend/ops/triton/linear/qk_rmsnorm.py | 301 ++++++++++++++++++ 6 files changed, 582 insertions(+), 3 deletions(-) create mode 100644 vllm_ascend/compilation/matcher_utils.py create mode 100644 vllm_ascend/compilation/passes/qknorm_fusion_pass.py create mode 100644 vllm_ascend/ops/triton/linear/qk_rmsnorm.py diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 48d23a7c49b..2aeefad2fde 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -231,7 +231,7 @@ class AscendCompilationConfig: deployed on Ascend platforms. """ - def __init__(self, fuse_norm_quant: bool = True, **kwargs): + def __init__(self, fuse_norm_quant: bool = True, fuse_qknorm_rope: bool = True, **kwargs): """ Initialize the configuration. @@ -239,11 +239,13 @@ def __init__(self, fuse_norm_quant: bool = True, **kwargs): fuse_norm_quant (bool): Whether to enable norm and quant fusion optimization. When set to True, the system will optimize norm and quant operations. Default: True + fuse_qknorm_rope (bool): Whether to enable qknorm and rope fusion optimization. + Default: True **kwargs: Additional optional parameters for forward compatibility and configuration extension. """ self.fuse_norm_quant = fuse_norm_quant - # Add more compilation related configs here as needed + self.fuse_qknorm_rope = fuse_qknorm_rope class XliteGraphConfig: diff --git a/vllm_ascend/compilation/graph_fusion_pass_manager.py b/vllm_ascend/compilation/graph_fusion_pass_manager.py index 2922869453b..09e21a6f070 100644 --- a/vllm_ascend/compilation/graph_fusion_pass_manager.py +++ b/vllm_ascend/compilation/graph_fusion_pass_manager.py @@ -50,4 +50,8 @@ def configure(self, config: VllmConfig): from .passes.norm_quant_fusion_pass import \ AddRMSNormQuantFusionPass self.passes.append(AddRMSNormQuantFusionPass(config)) - # Add more passes here as needed + + if self.ascend_compilation_config.get("fuse_qknorm", + True): + from .passes.qknorm_fusion_pass import QKNormFusionPass + self.passes.append(QKNormFusionPass(config)) diff --git a/vllm_ascend/compilation/matcher_utils.py b/vllm_ascend/compilation/matcher_utils.py new file mode 100644 index 00000000000..ecd536842f6 --- /dev/null +++ b/vllm_ascend/compilation/matcher_utils.py @@ -0,0 +1,61 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from abc import ABC, abstractmethod +import torch +import torch_npu + + +class MatcherCustomOp(ABC): + def __init__(self, epsilon: float): + self.epsilon = epsilon + + @abstractmethod + def forward(self, *args, **kws): + pass + + def __call__(self, *args, **kws): + return self.forward(*args, **kws) + + +class MatcherAscendRMSNorm(MatcherCustomOp): + + def forward( + self, + input: torch.Tensor, + weight: torch.Tensor, + ) -> torch.Tensor: + x, residual = torch_npu.npu_rms_norm( + input, weight, self.epsilon + ) + return x + + +class MatcherAscendRMSNormWithBias(MatcherCustomOp): + + def forward( + self, + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + ) -> torch.Tensor: + x, residual = torch_npu.npu_rms_norm( + input, weight, self.epsilon + ) + x.add_(bias) + return x + diff --git a/vllm_ascend/compilation/passes/qknorm_fusion_pass.py b/vllm_ascend/compilation/passes/qknorm_fusion_pass.py new file mode 100644 index 00000000000..fccb2c1a18c --- /dev/null +++ b/vllm_ascend/compilation/passes/qknorm_fusion_pass.py @@ -0,0 +1,210 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import logging + +import torch +import torch._inductor.pattern_matcher as pm +from torch._inductor.pattern_matcher import PatternMatcherPass +from vllm.config import get_current_vllm_config +from vllm.attention.layer import Attention +from vllm.compilation.vllm_inductor_pass import VllmInductorPass +from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm_ascend.compilation.matcher_utils import MatcherAscendRMSNorm, MatcherAscendRMSNormWithBias + +class QKNormFusionPattern: + + def __init__(self, head_dim, num_heads, num_kv_heads, eps=1e-6): + self.head_dim = head_dim + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.eps = eps + vllm_config = get_current_vllm_config() + self.device = vllm_config.device_config.device if vllm_config.device_config else None + self.rmsnorm_matcher = MatcherAscendRMSNorm(eps) + + def get_inputs(self): + T = 5 + qkv = torch.empty(T, self.q_size + 2 * self.kv_size, dtype=torch.bfloat16, device="npu") + q_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") + k_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") + return [qkv, q_weight, k_weight] + + def register(self, pm_pass: PatternMatcherPass): + def pattern( + qkv: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor + ): + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) + q_normed_by_head = self.rmsnorm_matcher(q_by_head, q_weight) + q_flat = q_normed_by_head.view(q.shape) + + k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim) + k_normed_by_head = self.rmsnorm_matcher(k_by_head, k_weight) + k_flat = k_normed_by_head.view(k.shape) + return q_flat, k_flat, v + + def replacement( + qkv: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor + ): + results = torch.ops.vllm.qk_rmsnorm( + input = qkv, + q_weight = q_weight, + k_weight = k_weight, + q_hidden_size=self.q_size, + kv_hidden_size=self.kv_size, + head_dim=self.head_dim, + eps=self.eps, + q_bias=None, + k_bias=None, + ) + return results + + pm.register_replacement( + pattern, + replacement, + self.get_inputs(), + pm.fwd_only, + pm_pass + ) + + +class QKNormFusionPatternWithBias: + + def __init__(self, head_dim, num_heads, num_kv_heads, eps=1e-6): + self.head_dim = head_dim + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.eps = eps + vllm_config = get_current_vllm_config() + self.device = vllm_config.device_config.device if vllm_config.device_config else None + self.rmsnorm_matcher = MatcherAscendRMSNormWithBias(eps) + + def get_inputs(self): + T = 5 + qkv = torch.empty(T, self.q_size + 2 * self.kv_size, dtype=torch.bfloat16, device="npu") + q_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") + k_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") + q_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") + k_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") + + return [qkv, q_weight, k_weight, q_bias, k_bias] + + def register(self, pm_pass: PatternMatcherPass): + def pattern( + qkv: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + q_bias: torch.Tensor, + k_bias: torch.Tensor + ): + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) + q_normed_by_head = self.rmsnorm_matcher(q_by_head, q_weight, q_bias) + q_flat = q_normed_by_head.view(q.shape) + + k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim) + k_normed_by_head = self.rmsnorm_matcher(k_by_head, k_weight, k_bias) + k_flat = k_normed_by_head.view(k.shape) + return q_flat, k_flat, v + + def replacement( + qkv: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + q_bias: torch.Tensor, + k_bias: torch.Tensor + ): + results = torch.ops.vllm.qk_rmsnorm( + input = qkv, + q_weight = q_weight, + k_weight = k_weight, + q_hidden_size=self.q_size, + kv_hidden_size=self.kv_size, + head_dim=self.head_dim, + eps=self.eps, + q_bias=q_bias, + k_bias=k_bias, + ) + return results + + pm.register_replacement( + pattern, + replacement, + self.get_inputs(), + pm.fwd_only, + pm_pass + ) + + +class QKNormFusionPass(VllmInductorPass): + """ + A pass for fusing QKV split and RMSNorm operations into a single qk_rmsnorm operator. + """ + + def __init__(self, vllm_config: VllmConfig): + super().__init__(vllm_config) + self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass( + pass_name="qknorm_fusion_pass") + + dtype = vllm_config.model_config.dtype + if dtype not in (torch.bfloat16, torch.float16): + logging.info("QKNorm fusion not enabled: unsupported dtype %s", + dtype) + return + + # use one attn layer to get meta (such as head_dim) for QkNormFusionPattern + attn_layers: dict[str, Attention] = get_layers_from_vllm_config( + vllm_config, Attention) + if len(attn_layers) == 0: + logging.info( + "QK Norm fusion enabled, but no Attention layers were discovered." + ) + return + layer = next(iter(attn_layers.values())) + + for epsilon in [1e-5, 1e-6]: + QKNormFusionPattern( + head_dim=layer.head_size, + num_heads = layer.num_heads, + num_kv_heads=layer.num_kv_heads, + eps=epsilon).register(self.pattern_match_passes) + + + def __call__(self, graph: torch.fx.Graph): + self.begin() + print("Graph before QK Norm Fusion Pass:") + print(graph.graph) + self.matched_count = self.pattern_match_passes.apply(graph) + print("Graph after QK Norm Fusion Pass:") + print(graph.graph) + logging.info("Fused %s QKNorm patterns", self.matched_count) + self.end_and_log() + + def is_applicable(self, runtime_shape): + """ + Check if the pass is applicable for the current configuration. + """ + return True diff --git a/vllm_ascend/ops/__init__.py b/vllm_ascend/ops/__init__.py index e121f2a442c..f9d29af6b63 100644 --- a/vllm_ascend/ops/__init__.py +++ b/vllm_ascend/ops/__init__.py @@ -21,6 +21,7 @@ import vllm_ascend.ops.layernorm # noqa import vllm_ascend.ops.register_custom_ops # noqa import vllm_ascend.ops.vocab_parallel_embedding # noqa +import vllm_ascend.ops.triton.linear.qk_rmsnorm # noqa from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul from vllm_ascend.ops.rotary_embedding import ( AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding) diff --git a/vllm_ascend/ops/triton/linear/qk_rmsnorm.py b/vllm_ascend/ops/triton/linear/qk_rmsnorm.py new file mode 100644 index 00000000000..ba0cd54c948 --- /dev/null +++ b/vllm_ascend/ops/triton/linear/qk_rmsnorm.py @@ -0,0 +1,301 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# +from typing import Optional + +import torch +import triton +import triton.language as tl +import triton.runtime.driver as driver +from vllm.utils.torch_utils import direct_register_custom_op + + +def get_npu_properties(): + device = torch.npu.current_device() + return driver.active.utils.get_device_properties(device) + + +@triton.jit +def qk_rmsnorm_triton_kernel( + input_ptr, + q_ptr, + k_ptr, + v_ptr, + q_weight_ptr, + k_weight_ptr, + batch_size, + q_hidden_size, + kv_hidden_size, + total_hidden_size, + eps, + Q_BLOCK_SIZE: tl.constexpr, + KV_BLOCK_SIZE: tl.constexpr, + HEAD_DIM: tl.constexpr, +): + row_pid = tl.program_id(0) + col_pid = tl.program_id(1) + row_step = tl.num_programs(0) + weight_values = tl.load(q_weight_ptr + tl.arange(0, HEAD_DIM)) + input_offset = row_pid * total_hidden_size + output_offset = row_pid * q_hidden_size + input_offset_step = row_step * total_hidden_size + output_offset_step = row_step * q_hidden_size + for _ in tl.range(row_pid, batch_size, row_step): + col_indices = col_pid * Q_BLOCK_SIZE + tl.arange(0, Q_BLOCK_SIZE) + valid_mask = col_indices < q_hidden_size + input_values = tl.load(input_ptr + input_offset + col_indices, + mask=valid_mask, + other=0.0).to(tl.float32).reshape( + Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM) + squares = input_values * input_values + variances = tl.sum(squares, axis=1) / HEAD_DIM + reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape( + Q_BLOCK_SIZE // HEAD_DIM, 1) + normalized_values = input_values * reciprocal_std + output_values = normalized_values * weight_values + tl.store(q_ptr + output_offset + col_indices, + output_values.to(tl.bfloat16).reshape(Q_BLOCK_SIZE), + mask=valid_mask) + input_offset += input_offset_step + output_offset += output_offset_step + weight_values = tl.load(k_weight_ptr + tl.arange(0, HEAD_DIM)) + input_offset = row_pid * total_hidden_size + q_hidden_size + output_offset = row_pid * kv_hidden_size + output_offset_step = row_step * kv_hidden_size + for _ in tl.range(row_pid, batch_size, row_step): + col_indices = col_pid * KV_BLOCK_SIZE + tl.arange(0, KV_BLOCK_SIZE) + valid_mask = col_indices < kv_hidden_size + input_values = tl.load(input_ptr + input_offset + col_indices, + mask=valid_mask, + other=0.0).to(tl.float32).reshape( + KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM) + squares = input_values * input_values + variances = tl.sum(squares, axis=1) / HEAD_DIM + reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape( + KV_BLOCK_SIZE // HEAD_DIM, 1) + normalized_values = input_values * reciprocal_std + output_values = normalized_values * weight_values + tl.store(k_ptr + output_offset + col_indices, + output_values.to(tl.bfloat16).reshape(KV_BLOCK_SIZE), + mask=valid_mask) + input_offset += input_offset_step + output_offset += output_offset_step + input_offset = row_pid * total_hidden_size + q_hidden_size + kv_hidden_size + output_offset = row_pid * kv_hidden_size + for _ in tl.range(row_pid, batch_size, row_step): + col_indices = col_pid * KV_BLOCK_SIZE + tl.arange(0, KV_BLOCK_SIZE) + valid_mask = col_indices < kv_hidden_size + input_values = tl.load(input_ptr + input_offset + col_indices, + mask=valid_mask, + other=0.0) + tl.store(v_ptr + output_offset + col_indices, + input_values, + mask=valid_mask) + input_offset += input_offset_step + output_offset += output_offset_step + +@triton.jit +def qk_rmsnorm_bias_triton_kernel( + input_ptr, + q_ptr, + k_ptr, + v_ptr, + q_weight_ptr, + q_bias_ptr, + k_weight_ptr, + k_bias_ptr, + batch_size, + q_hidden_size, + kv_hidden_size, + total_hidden_size, + eps, + Q_BLOCK_SIZE: tl.constexpr, + KV_BLOCK_SIZE: tl.constexpr, + HEAD_DIM: tl.constexpr, +): + row_pid = tl.program_id(0) + col_pid = tl.program_id(1) + row_step = tl.num_programs(0) + + # q norm + weight_values = tl.load(q_weight_ptr + tl.arange(0, HEAD_DIM)) + bias_values = tl.load(q_bias_ptr + tl.arange(0, HEAD_DIM)) + input_offset = row_pid * total_hidden_size + output_offset = row_pid * q_hidden_size + input_offset_step = row_step * total_hidden_size + output_offset_step = row_step * q_hidden_size + for _ in tl.range(row_pid, batch_size, row_step): + col_indices = col_pid * Q_BLOCK_SIZE + tl.arange(0, Q_BLOCK_SIZE) + valid_mask = col_indices < q_hidden_size + input_values = tl.load(input_ptr + input_offset + col_indices, + mask=valid_mask, + other=0.0).to(tl.float32).reshape( + Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM) + squares = input_values * input_values + variances = tl.sum(squares, axis=1) / HEAD_DIM + reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape( + Q_BLOCK_SIZE // HEAD_DIM, 1) + normalized_values = input_values * reciprocal_std # (Q_BLOCK_SIZE/HEAD_DIM, HEAD_DIM) + output_values = normalized_values * weight_values + bias_values + tl.store(q_ptr + output_offset + col_indices, + output_values.to(tl.bfloat16).reshape(Q_BLOCK_SIZE), + mask=valid_mask) + input_offset += input_offset_step + output_offset += output_offset_step + + # k norm + weight_values = tl.load(k_weight_ptr + tl.arange(0, HEAD_DIM)) + bias_values = tl.load(k_bias_ptr + tl.arange(0, HEAD_DIM)) + input_offset = row_pid * total_hidden_size + q_hidden_size + output_offset = row_pid * kv_hidden_size + output_offset_step = row_step * kv_hidden_size + for _ in tl.range(row_pid, batch_size, row_step): + col_indices = col_pid * KV_BLOCK_SIZE + tl.arange(0, KV_BLOCK_SIZE) + valid_mask = col_indices < kv_hidden_size + input_values = tl.load(input_ptr + input_offset + col_indices, + mask=valid_mask, + other=0.0).to(tl.float32).reshape( + KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM) + squares = input_values * input_values + variances = tl.sum(squares, axis=1) / HEAD_DIM + reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape( + KV_BLOCK_SIZE // HEAD_DIM, 1) + normalized_values = input_values * reciprocal_std # (KV_BLOCK_SIZE/HEAD_DIM, HEAD_DIM) + output_values = normalized_values * weight_values + bias_values + tl.store(k_ptr + output_offset + col_indices, + output_values.to(tl.bfloat16).reshape(KV_BLOCK_SIZE), + mask=valid_mask) + input_offset += input_offset_step + output_offset += output_offset_step + + # v copy + input_offset = row_pid * total_hidden_size + q_hidden_size + kv_hidden_size + output_offset = row_pid * kv_hidden_size + for _ in tl.range(row_pid, batch_size, row_step): + col_indices = col_pid * KV_BLOCK_SIZE + tl.arange(0, KV_BLOCK_SIZE) + valid_mask = col_indices < kv_hidden_size + input_values = tl.load(input_ptr + input_offset + col_indices, + mask=valid_mask, + other=0.0) + tl.store(v_ptr + output_offset + col_indices, + input_values, + mask=valid_mask) + input_offset += input_offset_step + output_offset += output_offset_step + + +num_core = get_npu_properties()["num_vectorcore"] + + +def qk_rmsnorm_impl( + input: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + q_hidden_size: int, + kv_hidden_size: int, + head_dim: int, + eps: float, + q_bias: Optional[torch.Tensor] = None, + k_bias: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + KV_BLOCK_SIZE = triton.next_power_of_2(head_dim) + assert KV_BLOCK_SIZE == head_dim + assert q_hidden_size % kv_hidden_size == 0 + + Q_BLOCK_SIZE = q_hidden_size // kv_hidden_size * head_dim + batch_size = input.shape[0] + total_hidden_size = q_hidden_size + kv_hidden_size * 2 + + q_output = torch.empty(batch_size, + q_hidden_size, + device=input.device, + dtype=input.dtype) + k_output = torch.empty(batch_size, + kv_hidden_size, + device=input.device, + dtype=input.dtype) + v_output = torch.empty(batch_size, + kv_hidden_size, + device=input.device, + dtype=input.dtype) + + n_cols = kv_hidden_size // KV_BLOCK_SIZE + assert num_core % n_cols == 0 + n_rows = num_core // n_cols + + if q_bias is None: + qk_rmsnorm_triton_kernel[(n_rows, n_cols)]( + input, + q_output, + k_output, + v_output, + q_weight, + k_weight, + batch_size, + q_hidden_size, + kv_hidden_size, + total_hidden_size, + eps, + Q_BLOCK_SIZE, + KV_BLOCK_SIZE, + head_dim, + ) + else: + qk_rmsnorm_bias_triton_kernel[(n_rows, n_cols)]( + input, + q_output, + k_output, + v_output, + q_weight, + q_bias, + k_weight, + k_bias, + batch_size, + q_hidden_size, + kv_hidden_size, + total_hidden_size, + eps, + Q_BLOCK_SIZE, + KV_BLOCK_SIZE, + head_dim, + ) + return q_output, k_output, v_output + + +def qk_rmsnorm_impl_fake( + input: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + q_hidden_size: int, + kv_hidden_size: int, + head_dim: int, + eps: float, + q_bias: Optional[torch.Tensor] = None, + k_bias: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # fake impl for shape inference + batch_size = input.shape[0] + q_output = torch.empty(batch_size, q_hidden_size, device=input.device, dtype=input.dtype) + k_output = torch.empty(batch_size, kv_hidden_size, device=input.device, dtype=input.dtype) + v_output = torch.empty(batch_size, kv_hidden_size, device=input.device, dtype=input.dtype) + return q_output, k_output, v_output + + +direct_register_custom_op(op_name="qk_rmsnorm", + op_func=qk_rmsnorm_impl, + fake_impl=qk_rmsnorm_impl_fake, + mutates_args=[], + dispatch_key="PrivateUse1") From 70ec1126ca411831cbeb69554e31c66a834c21fa Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Thu, 4 Dec 2025 06:20:16 +0000 Subject: [PATCH 02/40] adapt bias is not none Signed-off-by: wxsIcey <1790571317@qq.com> --- vllm_ascend/ascend_config.py | 6 +- .../compilation/graph_fusion_pass_manager.py | 5 +- vllm_ascend/compilation/matcher_utils.py | 61 ------ .../compilation/passes/qknorm_fusion_pass.py | 180 +++++++++--------- vllm_ascend/ops/__init__.py | 2 +- vllm_ascend/ops/triton/linear/qk_rmsnorm.py | 16 +- 6 files changed, 112 insertions(+), 158 deletions(-) delete mode 100644 vllm_ascend/compilation/matcher_utils.py diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 2aeefad2fde..716743e00a0 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -231,7 +231,11 @@ class AscendCompilationConfig: deployed on Ascend platforms. """ - def __init__(self, fuse_norm_quant: bool = True, fuse_qknorm_rope: bool = True, **kwargs): + def __init__(self, + fuse_norm_quant: bool = True, + fuse_qknorm_rope: bool = True, + **kwargs + ): """ Initialize the configuration. diff --git a/vllm_ascend/compilation/graph_fusion_pass_manager.py b/vllm_ascend/compilation/graph_fusion_pass_manager.py index 09e21a6f070..c968f0b3e67 100644 --- a/vllm_ascend/compilation/graph_fusion_pass_manager.py +++ b/vllm_ascend/compilation/graph_fusion_pass_manager.py @@ -50,8 +50,7 @@ def configure(self, config: VllmConfig): from .passes.norm_quant_fusion_pass import \ AddRMSNormQuantFusionPass self.passes.append(AddRMSNormQuantFusionPass(config)) - - if self.ascend_compilation_config.get("fuse_qknorm", - True): + + if self.ascend_compilation_config.get("fuse_qknorm", True): from .passes.qknorm_fusion_pass import QKNormFusionPass self.passes.append(QKNormFusionPass(config)) diff --git a/vllm_ascend/compilation/matcher_utils.py b/vllm_ascend/compilation/matcher_utils.py deleted file mode 100644 index ecd536842f6..00000000000 --- a/vllm_ascend/compilation/matcher_utils.py +++ /dev/null @@ -1,61 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# This file is a part of the vllm-ascend project. -# -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from abc import ABC, abstractmethod -import torch -import torch_npu - - -class MatcherCustomOp(ABC): - def __init__(self, epsilon: float): - self.epsilon = epsilon - - @abstractmethod - def forward(self, *args, **kws): - pass - - def __call__(self, *args, **kws): - return self.forward(*args, **kws) - - -class MatcherAscendRMSNorm(MatcherCustomOp): - - def forward( - self, - input: torch.Tensor, - weight: torch.Tensor, - ) -> torch.Tensor: - x, residual = torch_npu.npu_rms_norm( - input, weight, self.epsilon - ) - return x - - -class MatcherAscendRMSNormWithBias(MatcherCustomOp): - - def forward( - self, - input: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor, - ) -> torch.Tensor: - x, residual = torch_npu.npu_rms_norm( - input, weight, self.epsilon - ) - x.add_(bias) - return x - diff --git a/vllm_ascend/compilation/passes/qknorm_fusion_pass.py b/vllm_ascend/compilation/passes/qknorm_fusion_pass.py index fccb2c1a18c..b6c0f1f8cd0 100644 --- a/vllm_ascend/compilation/passes/qknorm_fusion_pass.py +++ b/vllm_ascend/compilation/passes/qknorm_fusion_pass.py @@ -20,11 +20,11 @@ import torch import torch._inductor.pattern_matcher as pm from torch._inductor.pattern_matcher import PatternMatcherPass -from vllm.config import get_current_vllm_config from vllm.attention.layer import Attention from vllm.compilation.vllm_inductor_pass import VllmInductorPass -from vllm.config import VllmConfig, get_layers_from_vllm_config -from vllm_ascend.compilation.matcher_utils import MatcherAscendRMSNorm, MatcherAscendRMSNormWithBias +from vllm.config import (VllmConfig, get_current_vllm_config, + get_layers_from_vllm_config) + class QKNormFusionPattern: @@ -37,40 +37,46 @@ def __init__(self, head_dim, num_heads, num_kv_heads, eps=1e-6): self.eps = eps vllm_config = get_current_vllm_config() self.device = vllm_config.device_config.device if vllm_config.device_config else None - self.rmsnorm_matcher = MatcherAscendRMSNorm(eps) def get_inputs(self): T = 5 - qkv = torch.empty(T, self.q_size + 2 * self.kv_size, dtype=torch.bfloat16, device="npu") - q_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") - k_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") + qkv = torch.empty(T, + self.q_size + 2 * self.kv_size, + dtype=torch.bfloat16, + device="npu") + q_weight = torch.empty(self.head_dim, + dtype=torch.bfloat16, + device="npu") + k_weight = torch.empty(self.head_dim, + dtype=torch.bfloat16, + device="npu") return [qkv, q_weight, k_weight] - + def register(self, pm_pass: PatternMatcherPass): - def pattern( - qkv: torch.Tensor, - q_weight: torch.Tensor, - k_weight: torch.Tensor - ): - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) - q_normed_by_head = self.rmsnorm_matcher(q_by_head, q_weight) - q_flat = q_normed_by_head.view(q.shape) - - k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim) - k_normed_by_head = self.rmsnorm_matcher(k_by_head, k_weight) - k_flat = k_normed_by_head.view(k.shape) + + def pattern(qkv: torch.Tensor, q_weight: torch.Tensor, + k_weight: torch.Tensor): + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], + dim=-1) + q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, + self.head_dim) + q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight, + self.eps) + q_flat = q_norm_out.view(q.shape) + + k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, + self.head_dim) + k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight, + self.eps) + k_flat = k_norm_out.view(k.shape) return q_flat, k_flat, v - def replacement( - qkv: torch.Tensor, - q_weight: torch.Tensor, - k_weight: torch.Tensor - ): + def replacement(qkv: torch.Tensor, q_weight: torch.Tensor, + k_weight: torch.Tensor): results = torch.ops.vllm.qk_rmsnorm( - input = qkv, - q_weight = q_weight, - k_weight = k_weight, + input=qkv, + q_weight=q_weight, + k_weight=k_weight, q_hidden_size=self.q_size, kv_hidden_size=self.kv_size, head_dim=self.head_dim, @@ -80,15 +86,10 @@ def replacement( ) return results - pm.register_replacement( - pattern, - replacement, - self.get_inputs(), - pm.fwd_only, - pm_pass - ) - - + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + class QKNormFusionPatternWithBias: def __init__(self, head_dim, num_heads, num_kv_heads, eps=1e-6): @@ -100,47 +101,53 @@ def __init__(self, head_dim, num_heads, num_kv_heads, eps=1e-6): self.eps = eps vllm_config = get_current_vllm_config() self.device = vllm_config.device_config.device if vllm_config.device_config else None - self.rmsnorm_matcher = MatcherAscendRMSNormWithBias(eps) def get_inputs(self): T = 5 - qkv = torch.empty(T, self.q_size + 2 * self.kv_size, dtype=torch.bfloat16, device="npu") - q_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") - k_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") + qkv = torch.empty(T, + self.q_size + 2 * self.kv_size, + dtype=torch.bfloat16, + device="npu") + q_weight = torch.empty(self.head_dim, + dtype=torch.bfloat16, + device="npu") + k_weight = torch.empty(self.head_dim, + dtype=torch.bfloat16, + device="npu") q_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") k_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") - + return [qkv, q_weight, k_weight, q_bias, k_bias] - + def register(self, pm_pass: PatternMatcherPass): - def pattern( - qkv: torch.Tensor, - q_weight: torch.Tensor, - k_weight: torch.Tensor, - q_bias: torch.Tensor, - k_bias: torch.Tensor - ): - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) - q_normed_by_head = self.rmsnorm_matcher(q_by_head, q_weight, q_bias) - q_flat = q_normed_by_head.view(q.shape) - - k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim) - k_normed_by_head = self.rmsnorm_matcher(k_by_head, k_weight, k_bias) - k_flat = k_normed_by_head.view(k.shape) + + def pattern(qkv: torch.Tensor, q_weight: torch.Tensor, + k_weight: torch.Tensor, q_bias: torch.Tensor, + k_bias: torch.Tensor): + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], + dim=-1) + + q_by_head = q.view(*q.shape[:-1], self.num_heads, self.head_dim) + q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight, + self.eps) + q_normed = q_norm_out + q_bias + q_flat = q_normed.view(q.shape) + + k_by_head = k.view(*k.shape[:-1], self.num_kv_heads, self.head_dim) + k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight, + self.eps) + k_normed = k_norm_out + k_bias + k_flat = k_normed.view(k.shape) + return q_flat, k_flat, v - def replacement( - qkv: torch.Tensor, - q_weight: torch.Tensor, - k_weight: torch.Tensor, - q_bias: torch.Tensor, - k_bias: torch.Tensor - ): + def replacement(qkv: torch.Tensor, q_weight: torch.Tensor, + k_weight: torch.Tensor, q_bias: torch.Tensor, + k_bias: torch.Tensor): results = torch.ops.vllm.qk_rmsnorm( - input = qkv, - q_weight = q_weight, - k_weight = k_weight, + input=qkv, + q_weight=q_weight, + k_weight=k_weight, q_hidden_size=self.q_size, kv_hidden_size=self.kv_size, head_dim=self.head_dim, @@ -150,13 +157,8 @@ def replacement( ) return results - pm.register_replacement( - pattern, - replacement, - self.get_inputs(), - pm.fwd_only, - pm_pass - ) + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) class QKNormFusionPass(VllmInductorPass): @@ -184,23 +186,23 @@ def __init__(self, vllm_config: VllmConfig): ) return layer = next(iter(attn_layers.values())) - - for epsilon in [1e-5, 1e-6]: - QKNormFusionPattern( - head_dim=layer.head_size, - num_heads = layer.num_heads, - num_kv_heads=layer.num_kv_heads, - eps=epsilon).register(self.pattern_match_passes) - + for epsilon in [1e-6, 1e-5]: + QKNormFusionPattern(head_dim=layer.head_size, + num_heads=layer.num_heads, + num_kv_heads=layer.num_kv_heads, + eps=epsilon).register( + self.pattern_match_passes) + + QKNormFusionPatternWithBias(head_dim=layer.head_size, + num_heads=layer.num_heads, + num_kv_heads=layer.num_kv_heads, + eps=epsilon).register( + self.pattern_match_passes) def __call__(self, graph: torch.fx.Graph): self.begin() - print("Graph before QK Norm Fusion Pass:") - print(graph.graph) self.matched_count = self.pattern_match_passes.apply(graph) - print("Graph after QK Norm Fusion Pass:") - print(graph.graph) - logging.info("Fused %s QKNorm patterns", self.matched_count) + logging.debug("Fused %s QKNorm patterns", self.matched_count) self.end_and_log() def is_applicable(self, runtime_shape): diff --git a/vllm_ascend/ops/__init__.py b/vllm_ascend/ops/__init__.py index f9d29af6b63..555d4ca970c 100644 --- a/vllm_ascend/ops/__init__.py +++ b/vllm_ascend/ops/__init__.py @@ -20,8 +20,8 @@ import vllm_ascend.ops.fused_moe.fused_moe # noqa import vllm_ascend.ops.layernorm # noqa import vllm_ascend.ops.register_custom_ops # noqa +import vllm_ascend.ops.triton.linear.qk_rmsnorm # noqa import vllm_ascend.ops.vocab_parallel_embedding # noqa -import vllm_ascend.ops.triton.linear.qk_rmsnorm # noqa from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul from vllm_ascend.ops.rotary_embedding import ( AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding) diff --git a/vllm_ascend/ops/triton/linear/qk_rmsnorm.py b/vllm_ascend/ops/triton/linear/qk_rmsnorm.py index ba0cd54c948..18f921f884b 100644 --- a/vllm_ascend/ops/triton/linear/qk_rmsnorm.py +++ b/vllm_ascend/ops/triton/linear/qk_rmsnorm.py @@ -107,6 +107,7 @@ def qk_rmsnorm_triton_kernel( input_offset += input_offset_step output_offset += output_offset_step + @triton.jit def qk_rmsnorm_bias_triton_kernel( input_ptr, @@ -288,9 +289,18 @@ def qk_rmsnorm_impl_fake( ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # fake impl for shape inference batch_size = input.shape[0] - q_output = torch.empty(batch_size, q_hidden_size, device=input.device, dtype=input.dtype) - k_output = torch.empty(batch_size, kv_hidden_size, device=input.device, dtype=input.dtype) - v_output = torch.empty(batch_size, kv_hidden_size, device=input.device, dtype=input.dtype) + q_output = torch.empty(batch_size, + q_hidden_size, + device=input.device, + dtype=input.dtype) + k_output = torch.empty(batch_size, + kv_hidden_size, + device=input.device, + dtype=input.dtype) + v_output = torch.empty(batch_size, + kv_hidden_size, + device=input.device, + dtype=input.dtype) return q_output, k_output, v_output From fb9e37e74bcd296f28fd9b4767127e6b8b70fe33 Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Thu, 4 Dec 2025 09:40:58 +0000 Subject: [PATCH 03/40] change to norm rope fusion Signed-off-by: wxsIcey <1790571317@qq.com> --- .../compilation/graph_fusion_pass_manager.py | 2 +- ...on_pass.py => qkvnorm_rope_fusion_pass.py} | 82 +++-- vllm_ascend/ops/__init__.py | 2 +- vllm_ascend/ops/triton/linear/qk_rmsnorm.py | 311 ------------------ vllm_ascend/ops/triton/linearnorm/__init__.py | 0 .../linearnorm/split_qkv_rmsnorm_rope.py | 310 +++++++++++++++++ 6 files changed, 372 insertions(+), 335 deletions(-) rename vllm_ascend/compilation/passes/{qknorm_fusion_pass.py => qkvnorm_rope_fusion_pass.py} (74%) delete mode 100644 vllm_ascend/ops/triton/linear/qk_rmsnorm.py create mode 100644 vllm_ascend/ops/triton/linearnorm/__init__.py create mode 100644 vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py diff --git a/vllm_ascend/compilation/graph_fusion_pass_manager.py b/vllm_ascend/compilation/graph_fusion_pass_manager.py index c968f0b3e67..b8ef27f0f4a 100644 --- a/vllm_ascend/compilation/graph_fusion_pass_manager.py +++ b/vllm_ascend/compilation/graph_fusion_pass_manager.py @@ -52,5 +52,5 @@ def configure(self, config: VllmConfig): self.passes.append(AddRMSNormQuantFusionPass(config)) if self.ascend_compilation_config.get("fuse_qknorm", True): - from .passes.qknorm_fusion_pass import QKNormFusionPass + from .passes.qkvnorm_rope_fusion_pass import QKNormFusionPass self.passes.append(QKNormFusionPass(config)) diff --git a/vllm_ascend/compilation/passes/qknorm_fusion_pass.py b/vllm_ascend/compilation/passes/qkvnorm_rope_fusion_pass.py similarity index 74% rename from vllm_ascend/compilation/passes/qknorm_fusion_pass.py rename to vllm_ascend/compilation/passes/qkvnorm_rope_fusion_pass.py index b6c0f1f8cd0..8064d5cab80 100644 --- a/vllm_ascend/compilation/passes/qknorm_fusion_pass.py +++ b/vllm_ascend/compilation/passes/qkvnorm_rope_fusion_pass.py @@ -25,7 +25,6 @@ from vllm.config import (VllmConfig, get_current_vllm_config, get_layers_from_vllm_config) - class QKNormFusionPattern: def __init__(self, head_dim, num_heads, num_kv_heads, eps=1e-6): @@ -50,30 +49,46 @@ def get_inputs(self): k_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") - return [qkv, q_weight, k_weight] + cos = torch.empty(1, T, 1, self.head_dim, + dtype=torch.bfloat16, + device="npu") + sin = torch.empty(1, T, 1, self.head_dim, + dtype=torch.bfloat16, + device="npu") + return [qkv, q_weight, k_weight, cos, sin] def register(self, pm_pass: PatternMatcherPass): def pattern(qkv: torch.Tensor, q_weight: torch.Tensor, - k_weight: torch.Tensor): + k_weight: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) - q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight, - self.eps) - q_flat = q_norm_out.view(q.shape) + q_norm_out, _ = torch.ops.npu.npu_rms_norm( + q_by_head, q_weight, self.eps) k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim) k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight, self.eps) + + q_flat = q_norm_out.view(q.shape) + q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim) + k_flat = k_norm_out.view(k.shape) - return q_flat, k_flat, v + k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim) + + q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(q_reshape, k_reshape, cos, sin) + + return q_rope, k_rope, v def replacement(qkv: torch.Tensor, q_weight: torch.Tensor, - k_weight: torch.Tensor): - results = torch.ops.vllm.qk_rmsnorm( + k_weight: torch.Tensor, cos: torch.Tensor, + sin: torch.Tensor): + results = torch.ops.vllm.qkv_rmsnorm_rope( input=qkv, q_weight=q_weight, k_weight=k_weight, @@ -83,7 +98,10 @@ def replacement(qkv: torch.Tensor, q_weight: torch.Tensor, eps=self.eps, q_bias=None, k_bias=None, + sin=sin, + cos=cos ) + return results pm.register_replacement(pattern, replacement, self.get_inputs(), @@ -116,14 +134,21 @@ def get_inputs(self): device="npu") q_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") k_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") + cos = torch.empty(1, T, 1, self.head_dim, + dtype=torch.bfloat16, + device="npu") + sin = torch.empty(1, T, 1, self.head_dim, + dtype=torch.bfloat16, + device="npu") + + return [qkv, q_weight, k_weight, q_bias, k_bias, cos, sin] - return [qkv, q_weight, k_weight, q_bias, k_bias] def register(self, pm_pass: PatternMatcherPass): def pattern(qkv: torch.Tensor, q_weight: torch.Tensor, k_weight: torch.Tensor, q_bias: torch.Tensor, - k_bias: torch.Tensor): + k_bias: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) @@ -131,20 +156,26 @@ def pattern(qkv: torch.Tensor, q_weight: torch.Tensor, q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight, self.eps) q_normed = q_norm_out + q_bias - q_flat = q_normed.view(q.shape) - + k_by_head = k.view(*k.shape[:-1], self.num_kv_heads, self.head_dim) k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight, self.eps) k_normed = k_norm_out + k_bias + + q_flat = q_normed.view(q.shape) + q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim) + k_flat = k_normed.view(k.shape) - - return q_flat, k_flat, v + k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim) + + q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(q_reshape, k_reshape, cos, sin) + + return q_rope, k_rope, v def replacement(qkv: torch.Tensor, q_weight: torch.Tensor, k_weight: torch.Tensor, q_bias: torch.Tensor, - k_bias: torch.Tensor): - results = torch.ops.vllm.qk_rmsnorm( + k_bias: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): + results = torch.ops.vllm.qkv_rmsnorm_rope( input=qkv, q_weight=q_weight, k_weight=k_weight, @@ -154,6 +185,8 @@ def replacement(qkv: torch.Tensor, q_weight: torch.Tensor, eps=self.eps, q_bias=q_bias, k_bias=k_bias, + cos=cos, + sin=sin ) return results @@ -193,15 +226,19 @@ def __init__(self, vllm_config: VllmConfig): eps=epsilon).register( self.pattern_match_passes) - QKNormFusionPatternWithBias(head_dim=layer.head_size, - num_heads=layer.num_heads, - num_kv_heads=layer.num_kv_heads, - eps=epsilon).register( - self.pattern_match_passes) + # QKNormFusionPatternWithBias(head_dim=layer.head_size, + # num_heads=layer.num_heads, + # num_kv_heads=layer.num_kv_heads, + # eps=epsilon).register( + # self.pattern_match_passes) def __call__(self, graph: torch.fx.Graph): self.begin() + print("before qkvnorm rope fusion pass graph----------------------------") + print(graph.graph) self.matched_count = self.pattern_match_passes.apply(graph) + print("after qkvnorm rope fusion pass graph----------------------------") + print(graph.graph) logging.debug("Fused %s QKNorm patterns", self.matched_count) self.end_and_log() @@ -210,3 +247,4 @@ def is_applicable(self, runtime_shape): Check if the pass is applicable for the current configuration. """ return True + diff --git a/vllm_ascend/ops/__init__.py b/vllm_ascend/ops/__init__.py index 555d4ca970c..558ff6b8102 100644 --- a/vllm_ascend/ops/__init__.py +++ b/vllm_ascend/ops/__init__.py @@ -20,7 +20,7 @@ import vllm_ascend.ops.fused_moe.fused_moe # noqa import vllm_ascend.ops.layernorm # noqa import vllm_ascend.ops.register_custom_ops # noqa -import vllm_ascend.ops.triton.linear.qk_rmsnorm # noqa +import vllm_ascend.ops.triton.linearnorm.split_qkv_rmsnorm_rope # noqa import vllm_ascend.ops.vocab_parallel_embedding # noqa from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul from vllm_ascend.ops.rotary_embedding import ( diff --git a/vllm_ascend/ops/triton/linear/qk_rmsnorm.py b/vllm_ascend/ops/triton/linear/qk_rmsnorm.py deleted file mode 100644 index 18f921f884b..00000000000 --- a/vllm_ascend/ops/triton/linear/qk_rmsnorm.py +++ /dev/null @@ -1,311 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# This file is a part of the vllm-ascend project. -# -from typing import Optional - -import torch -import triton -import triton.language as tl -import triton.runtime.driver as driver -from vllm.utils.torch_utils import direct_register_custom_op - - -def get_npu_properties(): - device = torch.npu.current_device() - return driver.active.utils.get_device_properties(device) - - -@triton.jit -def qk_rmsnorm_triton_kernel( - input_ptr, - q_ptr, - k_ptr, - v_ptr, - q_weight_ptr, - k_weight_ptr, - batch_size, - q_hidden_size, - kv_hidden_size, - total_hidden_size, - eps, - Q_BLOCK_SIZE: tl.constexpr, - KV_BLOCK_SIZE: tl.constexpr, - HEAD_DIM: tl.constexpr, -): - row_pid = tl.program_id(0) - col_pid = tl.program_id(1) - row_step = tl.num_programs(0) - weight_values = tl.load(q_weight_ptr + tl.arange(0, HEAD_DIM)) - input_offset = row_pid * total_hidden_size - output_offset = row_pid * q_hidden_size - input_offset_step = row_step * total_hidden_size - output_offset_step = row_step * q_hidden_size - for _ in tl.range(row_pid, batch_size, row_step): - col_indices = col_pid * Q_BLOCK_SIZE + tl.arange(0, Q_BLOCK_SIZE) - valid_mask = col_indices < q_hidden_size - input_values = tl.load(input_ptr + input_offset + col_indices, - mask=valid_mask, - other=0.0).to(tl.float32).reshape( - Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM) - squares = input_values * input_values - variances = tl.sum(squares, axis=1) / HEAD_DIM - reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape( - Q_BLOCK_SIZE // HEAD_DIM, 1) - normalized_values = input_values * reciprocal_std - output_values = normalized_values * weight_values - tl.store(q_ptr + output_offset + col_indices, - output_values.to(tl.bfloat16).reshape(Q_BLOCK_SIZE), - mask=valid_mask) - input_offset += input_offset_step - output_offset += output_offset_step - weight_values = tl.load(k_weight_ptr + tl.arange(0, HEAD_DIM)) - input_offset = row_pid * total_hidden_size + q_hidden_size - output_offset = row_pid * kv_hidden_size - output_offset_step = row_step * kv_hidden_size - for _ in tl.range(row_pid, batch_size, row_step): - col_indices = col_pid * KV_BLOCK_SIZE + tl.arange(0, KV_BLOCK_SIZE) - valid_mask = col_indices < kv_hidden_size - input_values = tl.load(input_ptr + input_offset + col_indices, - mask=valid_mask, - other=0.0).to(tl.float32).reshape( - KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM) - squares = input_values * input_values - variances = tl.sum(squares, axis=1) / HEAD_DIM - reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape( - KV_BLOCK_SIZE // HEAD_DIM, 1) - normalized_values = input_values * reciprocal_std - output_values = normalized_values * weight_values - tl.store(k_ptr + output_offset + col_indices, - output_values.to(tl.bfloat16).reshape(KV_BLOCK_SIZE), - mask=valid_mask) - input_offset += input_offset_step - output_offset += output_offset_step - input_offset = row_pid * total_hidden_size + q_hidden_size + kv_hidden_size - output_offset = row_pid * kv_hidden_size - for _ in tl.range(row_pid, batch_size, row_step): - col_indices = col_pid * KV_BLOCK_SIZE + tl.arange(0, KV_BLOCK_SIZE) - valid_mask = col_indices < kv_hidden_size - input_values = tl.load(input_ptr + input_offset + col_indices, - mask=valid_mask, - other=0.0) - tl.store(v_ptr + output_offset + col_indices, - input_values, - mask=valid_mask) - input_offset += input_offset_step - output_offset += output_offset_step - - -@triton.jit -def qk_rmsnorm_bias_triton_kernel( - input_ptr, - q_ptr, - k_ptr, - v_ptr, - q_weight_ptr, - q_bias_ptr, - k_weight_ptr, - k_bias_ptr, - batch_size, - q_hidden_size, - kv_hidden_size, - total_hidden_size, - eps, - Q_BLOCK_SIZE: tl.constexpr, - KV_BLOCK_SIZE: tl.constexpr, - HEAD_DIM: tl.constexpr, -): - row_pid = tl.program_id(0) - col_pid = tl.program_id(1) - row_step = tl.num_programs(0) - - # q norm - weight_values = tl.load(q_weight_ptr + tl.arange(0, HEAD_DIM)) - bias_values = tl.load(q_bias_ptr + tl.arange(0, HEAD_DIM)) - input_offset = row_pid * total_hidden_size - output_offset = row_pid * q_hidden_size - input_offset_step = row_step * total_hidden_size - output_offset_step = row_step * q_hidden_size - for _ in tl.range(row_pid, batch_size, row_step): - col_indices = col_pid * Q_BLOCK_SIZE + tl.arange(0, Q_BLOCK_SIZE) - valid_mask = col_indices < q_hidden_size - input_values = tl.load(input_ptr + input_offset + col_indices, - mask=valid_mask, - other=0.0).to(tl.float32).reshape( - Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM) - squares = input_values * input_values - variances = tl.sum(squares, axis=1) / HEAD_DIM - reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape( - Q_BLOCK_SIZE // HEAD_DIM, 1) - normalized_values = input_values * reciprocal_std # (Q_BLOCK_SIZE/HEAD_DIM, HEAD_DIM) - output_values = normalized_values * weight_values + bias_values - tl.store(q_ptr + output_offset + col_indices, - output_values.to(tl.bfloat16).reshape(Q_BLOCK_SIZE), - mask=valid_mask) - input_offset += input_offset_step - output_offset += output_offset_step - - # k norm - weight_values = tl.load(k_weight_ptr + tl.arange(0, HEAD_DIM)) - bias_values = tl.load(k_bias_ptr + tl.arange(0, HEAD_DIM)) - input_offset = row_pid * total_hidden_size + q_hidden_size - output_offset = row_pid * kv_hidden_size - output_offset_step = row_step * kv_hidden_size - for _ in tl.range(row_pid, batch_size, row_step): - col_indices = col_pid * KV_BLOCK_SIZE + tl.arange(0, KV_BLOCK_SIZE) - valid_mask = col_indices < kv_hidden_size - input_values = tl.load(input_ptr + input_offset + col_indices, - mask=valid_mask, - other=0.0).to(tl.float32).reshape( - KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM) - squares = input_values * input_values - variances = tl.sum(squares, axis=1) / HEAD_DIM - reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape( - KV_BLOCK_SIZE // HEAD_DIM, 1) - normalized_values = input_values * reciprocal_std # (KV_BLOCK_SIZE/HEAD_DIM, HEAD_DIM) - output_values = normalized_values * weight_values + bias_values - tl.store(k_ptr + output_offset + col_indices, - output_values.to(tl.bfloat16).reshape(KV_BLOCK_SIZE), - mask=valid_mask) - input_offset += input_offset_step - output_offset += output_offset_step - - # v copy - input_offset = row_pid * total_hidden_size + q_hidden_size + kv_hidden_size - output_offset = row_pid * kv_hidden_size - for _ in tl.range(row_pid, batch_size, row_step): - col_indices = col_pid * KV_BLOCK_SIZE + tl.arange(0, KV_BLOCK_SIZE) - valid_mask = col_indices < kv_hidden_size - input_values = tl.load(input_ptr + input_offset + col_indices, - mask=valid_mask, - other=0.0) - tl.store(v_ptr + output_offset + col_indices, - input_values, - mask=valid_mask) - input_offset += input_offset_step - output_offset += output_offset_step - - -num_core = get_npu_properties()["num_vectorcore"] - - -def qk_rmsnorm_impl( - input: torch.Tensor, - q_weight: torch.Tensor, - k_weight: torch.Tensor, - q_hidden_size: int, - kv_hidden_size: int, - head_dim: int, - eps: float, - q_bias: Optional[torch.Tensor] = None, - k_bias: Optional[torch.Tensor] = None, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - KV_BLOCK_SIZE = triton.next_power_of_2(head_dim) - assert KV_BLOCK_SIZE == head_dim - assert q_hidden_size % kv_hidden_size == 0 - - Q_BLOCK_SIZE = q_hidden_size // kv_hidden_size * head_dim - batch_size = input.shape[0] - total_hidden_size = q_hidden_size + kv_hidden_size * 2 - - q_output = torch.empty(batch_size, - q_hidden_size, - device=input.device, - dtype=input.dtype) - k_output = torch.empty(batch_size, - kv_hidden_size, - device=input.device, - dtype=input.dtype) - v_output = torch.empty(batch_size, - kv_hidden_size, - device=input.device, - dtype=input.dtype) - - n_cols = kv_hidden_size // KV_BLOCK_SIZE - assert num_core % n_cols == 0 - n_rows = num_core // n_cols - - if q_bias is None: - qk_rmsnorm_triton_kernel[(n_rows, n_cols)]( - input, - q_output, - k_output, - v_output, - q_weight, - k_weight, - batch_size, - q_hidden_size, - kv_hidden_size, - total_hidden_size, - eps, - Q_BLOCK_SIZE, - KV_BLOCK_SIZE, - head_dim, - ) - else: - qk_rmsnorm_bias_triton_kernel[(n_rows, n_cols)]( - input, - q_output, - k_output, - v_output, - q_weight, - q_bias, - k_weight, - k_bias, - batch_size, - q_hidden_size, - kv_hidden_size, - total_hidden_size, - eps, - Q_BLOCK_SIZE, - KV_BLOCK_SIZE, - head_dim, - ) - return q_output, k_output, v_output - - -def qk_rmsnorm_impl_fake( - input: torch.Tensor, - q_weight: torch.Tensor, - k_weight: torch.Tensor, - q_hidden_size: int, - kv_hidden_size: int, - head_dim: int, - eps: float, - q_bias: Optional[torch.Tensor] = None, - k_bias: Optional[torch.Tensor] = None, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - # fake impl for shape inference - batch_size = input.shape[0] - q_output = torch.empty(batch_size, - q_hidden_size, - device=input.device, - dtype=input.dtype) - k_output = torch.empty(batch_size, - kv_hidden_size, - device=input.device, - dtype=input.dtype) - v_output = torch.empty(batch_size, - kv_hidden_size, - device=input.device, - dtype=input.dtype) - return q_output, k_output, v_output - - -direct_register_custom_op(op_name="qk_rmsnorm", - op_func=qk_rmsnorm_impl, - fake_impl=qk_rmsnorm_impl_fake, - mutates_args=[], - dispatch_key="PrivateUse1") diff --git a/vllm_ascend/ops/triton/linearnorm/__init__.py b/vllm_ascend/ops/triton/linearnorm/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py b/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py new file mode 100644 index 00000000000..cf72a0c39b3 --- /dev/null +++ b/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py @@ -0,0 +1,310 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# +import torch +import triton +import triton.language as tl +import triton.runtime.driver as driver +from vllm.utils.torch_utils import direct_register_custom_op +from typing import Optional + +@triton.jit +def split_qkv_rmsnorm_rope_kernel( + input_ptr, + sin_ptr, + cos_ptr, + q_ptr, + k_ptr, + v_ptr, + q_weight_ptr, + q_bias_ptr, + k_weight_ptr, + k_bias_ptr, + batch_size, + q_hidden_size: tl.constexpr, + kv_hidden_size: tl.constexpr, + total_hidden_size: tl.constexpr, + eps: tl.constexpr, + Q_BLOCK_SIZE: tl.constexpr, + KV_BLOCK_SIZE: tl.constexpr, + BIAS: tl.constexpr, + HEAD_DIM: tl.constexpr, + HALF_HEAD_DIM: tl.constexpr, +): + row_pid = tl.program_id(0) + col_pid = tl.program_id(1) + row_step = tl.num_programs(0) + # q + weight_values = tl.load(q_weight_ptr + tl.arange(0, HEAD_DIM)) + if BIAS: + bias_values = tl.load(q_bias_ptr + tl.arange(0, HEAD_DIM)) + input_offset = row_pid * total_hidden_size + output_offset = row_pid * q_hidden_size + input_offset_step = row_step * total_hidden_size + output_offset_step = row_step * q_hidden_size + for row_idx in tl.range(row_pid, batch_size, row_step): + col_indices = col_pid * Q_BLOCK_SIZE + tl.arange(0, Q_BLOCK_SIZE) + valid_mask = col_indices < q_hidden_size + input_values = ( + tl.load(input_ptr + input_offset + col_indices, mask=valid_mask, other=0.0) + .to(tl.float32) + .reshape(Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM) + ) + squares = input_values * input_values + variances = tl.sum(squares, axis=1) / HEAD_DIM + reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape( + Q_BLOCK_SIZE // HEAD_DIM, 1 + ) + normalized_values = ( + input_values * reciprocal_std + ) # (Q_BLOCK_SIZE//HEAD_DIM, HEAD_DIM) + if BIAS: + normalized_values = (normalized_values * weight_values + bias_values).to( + tl.bfloat16 + ) + else: + normalized_values = (normalized_values * weight_values).to(tl.bfloat16) + # rope + sc_offsets = row_idx * HEAD_DIM + tl.arange(0, HEAD_DIM) + sin = (tl.load(sin_ptr + sc_offsets)).reshape(1, HEAD_DIM) + cos = (tl.load(cos_ptr + sc_offsets)).reshape(1, HEAD_DIM) + x1 = tl.extract_slice( + normalized_values, + offsets=(0, 0), + sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), + strides=(1, 1), + ) + x2 = tl.extract_slice( + normalized_values, + offsets=(0, HALF_HEAD_DIM), + sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), + strides=(1, 1), + ) + cat_x = tl.zeros((Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16) + cat_x = tl.insert_slice( + cat_x, + -x2, + offsets=(0, 0), + sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), + strides=(1, 1), + ) + cat_x = tl.insert_slice( + cat_x, + x1, + offsets=(0, HALF_HEAD_DIM), + sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), + strides=(1, 1), + ) + roped_q = cat_x * sin + normalized_values * cos + # store + tl.store( + q_ptr + output_offset + col_indices, + roped_q.reshape(Q_BLOCK_SIZE).to(q_ptr.dtype.element_ty), + mask=valid_mask, + ) + input_offset += input_offset_step + output_offset += output_offset_step + + # k + weight_values = tl.load(k_weight_ptr + tl.arange(0, HEAD_DIM)) + if BIAS: + bias_values = tl.load(k_bias_ptr + tl.arange(0, HEAD_DIM)) + input_offset = row_pid * total_hidden_size + q_hidden_size + output_offset = row_pid * kv_hidden_size + output_offset_step = row_step * kv_hidden_size + for row_idx in tl.range(row_pid, batch_size, row_step): + col_indices = col_pid * KV_BLOCK_SIZE + tl.arange(0, KV_BLOCK_SIZE) + valid_mask = col_indices < kv_hidden_size + input_values = ( + tl.load(input_ptr + input_offset + col_indices, mask=valid_mask, other=0.0) + .to(tl.float32) + .reshape(KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM) + ) + squares = input_values * input_values + variances = tl.sum(squares, axis=1) / HEAD_DIM + reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape( + KV_BLOCK_SIZE // HEAD_DIM, 1 + ) + normalized_values = ( + input_values * reciprocal_std + ) # (KV_BLOCK_SIZE/HEAD_DIM, HEAD_DIM) + if BIAS: + normalized_values = (normalized_values * weight_values + bias_values).to( + tl.bfloat16 + ) + else: + normalized_values = (normalized_values * weight_values).to(tl.bfloat16) + # # rope + sc_offsets = row_idx * HEAD_DIM + tl.arange(0, HEAD_DIM) + sin = (tl.load(sin_ptr + sc_offsets)).reshape(1, HEAD_DIM) + cos = (tl.load(cos_ptr + sc_offsets)).reshape(1, HEAD_DIM) + x1 = tl.extract_slice( + normalized_values, + offsets=(0, 0), + sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), + strides=(1, 1), + ) + x2 = tl.extract_slice( + normalized_values, + offsets=(0, HALF_HEAD_DIM), + sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), + strides=(1, 1), + ) + cat_x = tl.zeros((KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16) + cat_x = tl.insert_slice( + cat_x, + -x2, + offsets=(0, 0), + sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), + strides=(1, 1), + ) + cat_x = tl.insert_slice( + cat_x, + x1, + offsets=(0, HALF_HEAD_DIM), + sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), + strides=(1, 1), + ) + roped_k = cat_x * sin + normalized_values * cos + # store + tl.store( + k_ptr + output_offset + col_indices, + roped_k.to(tl.bfloat16).reshape(KV_BLOCK_SIZE), + mask=valid_mask, + ) + input_offset += input_offset_step + output_offset += output_offset_step + + # v + input_offset = row_pid * total_hidden_size + q_hidden_size + kv_hidden_size + output_offset = row_pid * kv_hidden_size + for _ in tl.range(row_pid, batch_size, row_step): + col_indices = col_pid * KV_BLOCK_SIZE + tl.arange(0, KV_BLOCK_SIZE) + valid_mask = col_indices < kv_hidden_size + input_values = tl.load( + input_ptr + input_offset + col_indices, mask=valid_mask, other=0.0 + ) + tl.store(v_ptr + output_offset + col_indices, input_values, mask=valid_mask) + input_offset += input_offset_step + output_offset += output_offset_step + + +kernels = {} + +def get_npu_properties(): + device = torch.npu.current_device() + return driver.active.utils.get_device_properties(device) + +num_vectorcore = get_npu_properties()["num_vectorcore"] + +def split_qkv_rmsnorm_rope_impl( + input: torch.Tensor, + sin: torch.Tensor, + cos: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + q_hidden_size: int, + kv_hidden_size: int, + head_dim: int, + eps: float, + q_bias: Optional[torch.Tensor], + k_bias: Optional[torch.Tensor], +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + KV_BLOCK_SIZE = triton.next_power_of_2(head_dim) + assert KV_BLOCK_SIZE == head_dim + assert q_hidden_size % kv_hidden_size == 0 + Q_BLOCK_SIZE = q_hidden_size // kv_hidden_size * head_dim + batch_size = input.shape[0] + total_hidden_size = q_hidden_size + kv_hidden_size * 2 + q_output = torch.empty( + batch_size, q_hidden_size, device=input.device, dtype=input.dtype + ) + k_output = torch.empty( + batch_size, kv_hidden_size, device=input.device, dtype=input.dtype + ) + v_output = torch.empty( + batch_size, kv_hidden_size, device=input.device, dtype=input.dtype + ) + n_cols = kv_hidden_size // KV_BLOCK_SIZE + assert num_vectorcore % n_cols == 0 + n_rows = num_vectorcore // n_cols + BIAS = q_bias is not None + + split_qkv_rmsnorm_rope_kernel[(n_rows, n_cols, 1)]( + input, + sin, + cos, + q_output, + k_output, + v_output, + q_weight, + q_bias, + k_weight, + k_bias, + batch_size, + q_hidden_size, + kv_hidden_size, + total_hidden_size, + eps, + Q_BLOCK_SIZE, + KV_BLOCK_SIZE, + BIAS, + head_dim, + head_dim // 2, + ) + return q_output, k_output, v_output + +def split_qkv_rmsnorm_rope_impl_fake( + input: torch.Tensor, + sin: torch.Tensor, + cos: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + q_hidden_size: int, + kv_hidden_size: int, + head_dim: int, + eps: float, + q_bias: Optional[torch.Tensor] = None, + k_bias: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # Fake implementation for shape inference during Dynamo/AOT tracing. + # Note: sin and cos are not used in shape computation, but must be present in signature. + batch_size = input.shape[0] + q_output = torch.empty( + batch_size, + q_hidden_size, + device=input.device, + dtype=input.dtype, + ) + k_output = torch.empty( + batch_size, + kv_hidden_size, + device=input.device, + dtype=input.dtype, + ) + v_output = torch.empty( + batch_size, + kv_hidden_size, + device=input.device, + dtype=input.dtype, + ) + return q_output, k_output, v_output + +direct_register_custom_op(op_name="qkv_rmsnorm_rope", + op_func=split_qkv_rmsnorm_rope_impl, + fake_impl=split_qkv_rmsnorm_rope_impl_fake, + mutates_args=[], + dispatch_key="PrivateUse1") From 78cc0a2d819692b7bbe1f5de397fd3796abe70de Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Thu, 4 Dec 2025 09:49:17 +0000 Subject: [PATCH 04/40] tiny fix Signed-off-by: wxsIcey <1790571317@qq.com> --- vllm_ascend/compilation/passes/qkvnorm_rope_fusion_pass.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm_ascend/compilation/passes/qkvnorm_rope_fusion_pass.py b/vllm_ascend/compilation/passes/qkvnorm_rope_fusion_pass.py index 8064d5cab80..e87e6f1f56f 100644 --- a/vllm_ascend/compilation/passes/qkvnorm_rope_fusion_pass.py +++ b/vllm_ascend/compilation/passes/qkvnorm_rope_fusion_pass.py @@ -152,12 +152,12 @@ def pattern(qkv: torch.Tensor, q_weight: torch.Tensor, q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q_by_head = q.view(*q.shape[:-1], self.num_heads, self.head_dim) + q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight, self.eps) q_normed = q_norm_out + q_bias - k_by_head = k.view(*k.shape[:-1], self.num_kv_heads, self.head_dim) + k_by_head = k.view(*k.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight, self.eps) k_normed = k_norm_out + k_bias From ef734cbf221e00248426ec97d4d2694be3925007 Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Thu, 4 Dec 2025 11:59:29 +0000 Subject: [PATCH 05/40] tiny fix Signed-off-by: wxsIcey <1790571317@qq.com> --- .../compilation/passes/qkvnorm_rope_fusion_pass.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm_ascend/compilation/passes/qkvnorm_rope_fusion_pass.py b/vllm_ascend/compilation/passes/qkvnorm_rope_fusion_pass.py index e87e6f1f56f..aed89731d0c 100644 --- a/vllm_ascend/compilation/passes/qkvnorm_rope_fusion_pass.py +++ b/vllm_ascend/compilation/passes/qkvnorm_rope_fusion_pass.py @@ -157,7 +157,7 @@ def pattern(qkv: torch.Tensor, q_weight: torch.Tensor, self.eps) q_normed = q_norm_out + q_bias - k_by_head = k.view(*k.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) + k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim) k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight, self.eps) k_normed = k_norm_out + k_bias @@ -226,11 +226,11 @@ def __init__(self, vllm_config: VllmConfig): eps=epsilon).register( self.pattern_match_passes) - # QKNormFusionPatternWithBias(head_dim=layer.head_size, - # num_heads=layer.num_heads, - # num_kv_heads=layer.num_kv_heads, - # eps=epsilon).register( - # self.pattern_match_passes) + QKNormFusionPatternWithBias(head_dim=layer.head_size, + num_heads=layer.num_heads, + num_kv_heads=layer.num_kv_heads, + eps=epsilon).register( + self.pattern_match_passes) def __call__(self, graph: torch.fx.Graph): self.begin() From e89886766d64926db84266891e1331bacb09469f Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Fri, 5 Dec 2025 04:01:25 +0000 Subject: [PATCH 06/40] normalize fusion naming and format code Signed-off-by: wxsIcey <1790571317@qq.com> --- vllm_ascend/ascend_config.py | 2 + .../compilation/graph_fusion_pass_manager.py | 6 +- ...ion_pass.py => qknorm_rope_fusion_pass.py} | 159 ++++++++++-------- .../linearnorm/split_qkv_rmsnorm_rope.py | 146 ++++++++-------- 4 files changed, 172 insertions(+), 141 deletions(-) rename vllm_ascend/compilation/passes/{qkvnorm_rope_fusion_pass.py => qknorm_rope_fusion_pass.py} (69%) diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 716743e00a0..01f988859be 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -246,6 +246,8 @@ def __init__(self, fuse_qknorm_rope (bool): Whether to enable qknorm and rope fusion optimization. Default: True + fuse_qknorm_rope (bool): Whether to enable qknorm and rope fusion optimization. + Default: True **kwargs: Additional optional parameters for forward compatibility and configuration extension. """ self.fuse_norm_quant = fuse_norm_quant diff --git a/vllm_ascend/compilation/graph_fusion_pass_manager.py b/vllm_ascend/compilation/graph_fusion_pass_manager.py index b8ef27f0f4a..e311b2602a7 100644 --- a/vllm_ascend/compilation/graph_fusion_pass_manager.py +++ b/vllm_ascend/compilation/graph_fusion_pass_manager.py @@ -51,6 +51,6 @@ def configure(self, config: VllmConfig): AddRMSNormQuantFusionPass self.passes.append(AddRMSNormQuantFusionPass(config)) - if self.ascend_compilation_config.get("fuse_qknorm", True): - from .passes.qkvnorm_rope_fusion_pass import QKNormFusionPass - self.passes.append(QKNormFusionPass(config)) + if self.ascend_compilation_config.get("fuse_qknorm_rope", True): + from .passes.qknorm_rope_fusion_pass import QKNormRopeFusionPass + self.passes.append(QKNormRopeFusionPass(config)) diff --git a/vllm_ascend/compilation/passes/qkvnorm_rope_fusion_pass.py b/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py similarity index 69% rename from vllm_ascend/compilation/passes/qkvnorm_rope_fusion_pass.py rename to vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py index aed89731d0c..ea900c646fe 100644 --- a/vllm_ascend/compilation/passes/qkvnorm_rope_fusion_pass.py +++ b/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py @@ -19,13 +19,15 @@ import torch import torch._inductor.pattern_matcher as pm -from torch._inductor.pattern_matcher import PatternMatcherPass +from torch._inductor.pattern_matcher import (PatternMatcherPass, + PatternPrettyPrinter) from vllm.attention.layer import Attention from vllm.compilation.vllm_inductor_pass import VllmInductorPass from vllm.config import (VllmConfig, get_current_vllm_config, get_layers_from_vllm_config) -class QKNormFusionPattern: + +class QKNormRopeFusionPattern: def __init__(self, head_dim, num_heads, num_kv_heads, eps=1e-6): self.head_dim = head_dim @@ -49,26 +51,33 @@ def get_inputs(self): k_weight = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") - cos = torch.empty(1, T, 1, self.head_dim, - dtype=torch.bfloat16, - device="npu") - sin = torch.empty(1, T, 1, self.head_dim, - dtype=torch.bfloat16, - device="npu") + cos = torch.empty(1, + T, + 1, + self.head_dim, + dtype=torch.bfloat16, + device="npu") + sin = torch.empty(1, + T, + 1, + self.head_dim, + dtype=torch.bfloat16, + device="npu") return [qkv, q_weight, k_weight, cos, sin] def register(self, pm_pass: PatternMatcherPass): def pattern(qkv: torch.Tensor, q_weight: torch.Tensor, - k_weight: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): - + k_weight: torch.Tensor, cos: torch.Tensor, + sin: torch.Tensor): + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - + q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) - q_norm_out, _ = torch.ops.npu.npu_rms_norm( - q_by_head, q_weight, self.eps) + q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight, + self.eps) k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim) @@ -76,13 +85,16 @@ def pattern(qkv: torch.Tensor, q_weight: torch.Tensor, self.eps) q_flat = q_norm_out.view(q.shape) - q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim) - + q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, + self.head_dim) + k_flat = k_norm_out.view(k.shape) - k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim) - - q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(q_reshape, k_reshape, cos, sin) - + k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, + self.head_dim) + + q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb( + q_reshape, k_reshape, cos, sin) + return q_rope, k_rope, v def replacement(qkv: torch.Tensor, q_weight: torch.Tensor, @@ -99,16 +111,15 @@ def replacement(qkv: torch.Tensor, q_weight: torch.Tensor, q_bias=None, k_bias=None, sin=sin, - cos=cos - ) - + cos=cos) + return results pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass) -class QKNormFusionPatternWithBias: +class QKNormRopeFusionPatternWithBias: def __init__(self, head_dim, num_heads, num_kv_heads, eps=1e-6): self.head_dim = head_dim @@ -134,47 +145,59 @@ def get_inputs(self): device="npu") q_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") k_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") - cos = torch.empty(1, T, 1, self.head_dim, - dtype=torch.bfloat16, - device="npu") - sin = torch.empty(1, T, 1, self.head_dim, - dtype=torch.bfloat16, - device="npu") + cos = torch.empty(1, + T, + 1, + self.head_dim, + dtype=torch.bfloat16, + device="npu") + sin = torch.empty(1, + T, + 1, + self.head_dim, + dtype=torch.bfloat16, + device="npu") return [qkv, q_weight, k_weight, q_bias, k_bias, cos, sin] - def register(self, pm_pass: PatternMatcherPass): def pattern(qkv: torch.Tensor, q_weight: torch.Tensor, k_weight: torch.Tensor, q_bias: torch.Tensor, - k_bias: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): + k_bias: torch.Tensor, cos: torch.Tensor, + sin: torch.Tensor): q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) + q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, + self.head_dim) q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight, self.eps) q_normed = q_norm_out + q_bias - - k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim) + + k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, + self.head_dim) k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight, self.eps) k_normed = k_norm_out + k_bias - + q_flat = q_normed.view(q.shape) - q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim) - + q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, + self.head_dim) + k_flat = k_normed.view(k.shape) - k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim) - - q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb(q_reshape, k_reshape, cos, sin) - + k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, + self.head_dim) + + q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb( + q_reshape, k_reshape, cos, sin) + return q_rope, k_rope, v def replacement(qkv: torch.Tensor, q_weight: torch.Tensor, k_weight: torch.Tensor, q_bias: torch.Tensor, - k_bias: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): + k_bias: torch.Tensor, cos: torch.Tensor, + sin: torch.Tensor): results = torch.ops.vllm.qkv_rmsnorm_rope( input=qkv, q_weight=q_weight, @@ -186,15 +209,14 @@ def replacement(qkv: torch.Tensor, q_weight: torch.Tensor, q_bias=q_bias, k_bias=k_bias, cos=cos, - sin=sin - ) + sin=sin) return results pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass) -class QKNormFusionPass(VllmInductorPass): +class QKNormRopeFusionPass(VllmInductorPass): """ A pass for fusing QKV split and RMSNorm operations into a single qk_rmsnorm operator. """ @@ -202,44 +224,48 @@ class QKNormFusionPass(VllmInductorPass): def __init__(self, vllm_config: VllmConfig): super().__init__(vllm_config) self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass( - pass_name="qknorm_fusion_pass") + pass_name="qknorm_rope_fusion_pass") dtype = vllm_config.model_config.dtype if dtype not in (torch.bfloat16, torch.float16): - logging.info("QKNorm fusion not enabled: unsupported dtype %s", - dtype) + logging.info( + "QKNorm and Rope fusion not enabled: unsupported dtype %s", + dtype) return - # use one attn layer to get meta (such as head_dim) for QkNormFusionPattern + # use one attn layer to get meta (such as head_dim) for QKNormRopeFusionPattern attn_layers: dict[str, Attention] = get_layers_from_vllm_config( vllm_config, Attention) if len(attn_layers) == 0: logging.info( - "QK Norm fusion enabled, but no Attention layers were discovered." + "QKNorm and Rope fusion enabled, but no Attention layers were discovered." ) return layer = next(iter(attn_layers.values())) for epsilon in [1e-6, 1e-5]: - QKNormFusionPattern(head_dim=layer.head_size, - num_heads=layer.num_heads, - num_kv_heads=layer.num_kv_heads, - eps=epsilon).register( - self.pattern_match_passes) - - QKNormFusionPatternWithBias(head_dim=layer.head_size, - num_heads=layer.num_heads, - num_kv_heads=layer.num_kv_heads, - eps=epsilon).register( - self.pattern_match_passes) + QKNormRopeFusionPattern(head_dim=layer.head_size, + num_heads=layer.num_heads, + num_kv_heads=layer.num_kv_heads, + eps=epsilon).register( + self.pattern_match_passes) + + QKNormRopeFusionPatternWithBias(head_dim=layer.head_size, + num_heads=layer.num_heads, + num_kv_heads=layer.num_kv_heads, + eps=epsilon).register( + self.pattern_match_passes) def __call__(self, graph: torch.fx.Graph): self.begin() - print("before qkvnorm rope fusion pass graph----------------------------") - print(graph.graph) self.matched_count = self.pattern_match_passes.apply(graph) - print("after qkvnorm rope fusion pass graph----------------------------") - print(graph.graph) - logging.debug("Fused %s QKNorm patterns", self.matched_count) + logging.debug("Fused %s QKNorm and Rope patterns", self.matched_count) + logging.debug("Patterns registered for replacement:") + pattern_idx = 0 + for pattern_entry in self.pattern_match_passes.patterns.values(): + for p in pattern_entry: + p_str = PatternPrettyPrinter.run(p.pattern) + logging.debug("Pattern %d: %s", pattern_idx, p_str) + pattern_idx += 1 self.end_and_log() def is_applicable(self, runtime_shape): @@ -247,4 +273,3 @@ def is_applicable(self, runtime_shape): Check if the pass is applicable for the current configuration. """ return True - diff --git a/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py b/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py index cf72a0c39b3..9d51e0a1ea2 100644 --- a/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py +++ b/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py @@ -14,12 +14,14 @@ # limitations under the License. # This file is a part of the vllm-ascend project. # +from typing import Optional + import torch import triton import triton.language as tl import triton.runtime.driver as driver from vllm.utils.torch_utils import direct_register_custom_op -from typing import Optional + @triton.jit def split_qkv_rmsnorm_rope_kernel( @@ -58,26 +60,23 @@ def split_qkv_rmsnorm_rope_kernel( for row_idx in tl.range(row_pid, batch_size, row_step): col_indices = col_pid * Q_BLOCK_SIZE + tl.arange(0, Q_BLOCK_SIZE) valid_mask = col_indices < q_hidden_size - input_values = ( - tl.load(input_ptr + input_offset + col_indices, mask=valid_mask, other=0.0) - .to(tl.float32) - .reshape(Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM) - ) + input_values = (tl.load(input_ptr + input_offset + col_indices, + mask=valid_mask, + other=0.0).to(tl.float32).reshape( + Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM)) squares = input_values * input_values variances = tl.sum(squares, axis=1) / HEAD_DIM reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape( - Q_BLOCK_SIZE // HEAD_DIM, 1 - ) - normalized_values = ( - input_values * reciprocal_std - ) # (Q_BLOCK_SIZE//HEAD_DIM, HEAD_DIM) + Q_BLOCK_SIZE // HEAD_DIM, 1) + normalized_values = (input_values * reciprocal_std + ) # (Q_BLOCK_SIZE//HEAD_DIM, HEAD_DIM) if BIAS: - normalized_values = (normalized_values * weight_values + bias_values).to( - tl.bfloat16 - ) + normalized_values = (normalized_values * weight_values + + bias_values).to(tl.bfloat16) else: - normalized_values = (normalized_values * weight_values).to(tl.bfloat16) - # rope + normalized_values = (normalized_values * weight_values).to( + tl.bfloat16) + sc_offsets = row_idx * HEAD_DIM + tl.arange(0, HEAD_DIM) sin = (tl.load(sin_ptr + sc_offsets)).reshape(1, HEAD_DIM) cos = (tl.load(cos_ptr + sc_offsets)).reshape(1, HEAD_DIM) @@ -93,7 +92,8 @@ def split_qkv_rmsnorm_rope_kernel( sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), strides=(1, 1), ) - cat_x = tl.zeros((Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16) + cat_x = tl.zeros((Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), + dtype=tl.bfloat16) cat_x = tl.insert_slice( cat_x, -x2, @@ -109,7 +109,6 @@ def split_qkv_rmsnorm_rope_kernel( strides=(1, 1), ) roped_q = cat_x * sin + normalized_values * cos - # store tl.store( q_ptr + output_offset + col_indices, roped_q.reshape(Q_BLOCK_SIZE).to(q_ptr.dtype.element_ty), @@ -118,7 +117,6 @@ def split_qkv_rmsnorm_rope_kernel( input_offset += input_offset_step output_offset += output_offset_step - # k weight_values = tl.load(k_weight_ptr + tl.arange(0, HEAD_DIM)) if BIAS: bias_values = tl.load(k_bias_ptr + tl.arange(0, HEAD_DIM)) @@ -128,26 +126,22 @@ def split_qkv_rmsnorm_rope_kernel( for row_idx in tl.range(row_pid, batch_size, row_step): col_indices = col_pid * KV_BLOCK_SIZE + tl.arange(0, KV_BLOCK_SIZE) valid_mask = col_indices < kv_hidden_size - input_values = ( - tl.load(input_ptr + input_offset + col_indices, mask=valid_mask, other=0.0) - .to(tl.float32) - .reshape(KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM) - ) + input_values = (tl.load(input_ptr + input_offset + col_indices, + mask=valid_mask, + other=0.0).to(tl.float32).reshape( + KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM)) squares = input_values * input_values variances = tl.sum(squares, axis=1) / HEAD_DIM reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape( - KV_BLOCK_SIZE // HEAD_DIM, 1 - ) - normalized_values = ( - input_values * reciprocal_std - ) # (KV_BLOCK_SIZE/HEAD_DIM, HEAD_DIM) + KV_BLOCK_SIZE // HEAD_DIM, 1) + normalized_values = (input_values * reciprocal_std + ) # (KV_BLOCK_SIZE/HEAD_DIM, HEAD_DIM) if BIAS: - normalized_values = (normalized_values * weight_values + bias_values).to( - tl.bfloat16 - ) + normalized_values = (normalized_values * weight_values + + bias_values).to(tl.bfloat16) else: - normalized_values = (normalized_values * weight_values).to(tl.bfloat16) - # # rope + normalized_values = (normalized_values * weight_values).to( + tl.bfloat16) sc_offsets = row_idx * HEAD_DIM + tl.arange(0, HEAD_DIM) sin = (tl.load(sin_ptr + sc_offsets)).reshape(1, HEAD_DIM) cos = (tl.load(cos_ptr + sc_offsets)).reshape(1, HEAD_DIM) @@ -163,7 +157,8 @@ def split_qkv_rmsnorm_rope_kernel( sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), strides=(1, 1), ) - cat_x = tl.zeros((KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16) + cat_x = tl.zeros((KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), + dtype=tl.bfloat16) cat_x = tl.insert_slice( cat_x, -x2, @@ -179,7 +174,7 @@ def split_qkv_rmsnorm_rope_kernel( strides=(1, 1), ) roped_k = cat_x * sin + normalized_values * cos - # store + tl.store( k_ptr + output_offset + col_indices, roped_k.to(tl.bfloat16).reshape(KV_BLOCK_SIZE), @@ -188,28 +183,32 @@ def split_qkv_rmsnorm_rope_kernel( input_offset += input_offset_step output_offset += output_offset_step - # v input_offset = row_pid * total_hidden_size + q_hidden_size + kv_hidden_size output_offset = row_pid * kv_hidden_size for _ in tl.range(row_pid, batch_size, row_step): col_indices = col_pid * KV_BLOCK_SIZE + tl.arange(0, KV_BLOCK_SIZE) valid_mask = col_indices < kv_hidden_size - input_values = tl.load( - input_ptr + input_offset + col_indices, mask=valid_mask, other=0.0 - ) - tl.store(v_ptr + output_offset + col_indices, input_values, mask=valid_mask) + input_values = tl.load(input_ptr + input_offset + col_indices, + mask=valid_mask, + other=0.0) + tl.store(v_ptr + output_offset + col_indices, + input_values, + mask=valid_mask) input_offset += input_offset_step output_offset += output_offset_step kernels = {} + def get_npu_properties(): device = torch.npu.current_device() return driver.active.utils.get_device_properties(device) + num_vectorcore = get_npu_properties()["num_vectorcore"] + def split_qkv_rmsnorm_rope_impl( input: torch.Tensor, sin: torch.Tensor, @@ -229,44 +228,48 @@ def split_qkv_rmsnorm_rope_impl( Q_BLOCK_SIZE = q_hidden_size // kv_hidden_size * head_dim batch_size = input.shape[0] total_hidden_size = q_hidden_size + kv_hidden_size * 2 - q_output = torch.empty( - batch_size, q_hidden_size, device=input.device, dtype=input.dtype - ) - k_output = torch.empty( - batch_size, kv_hidden_size, device=input.device, dtype=input.dtype - ) - v_output = torch.empty( - batch_size, kv_hidden_size, device=input.device, dtype=input.dtype - ) + q_output = torch.empty(batch_size, + q_hidden_size, + device=input.device, + dtype=input.dtype) + k_output = torch.empty(batch_size, + kv_hidden_size, + device=input.device, + dtype=input.dtype) + v_output = torch.empty(batch_size, + kv_hidden_size, + device=input.device, + dtype=input.dtype) n_cols = kv_hidden_size // KV_BLOCK_SIZE assert num_vectorcore % n_cols == 0 n_rows = num_vectorcore // n_cols BIAS = q_bias is not None split_qkv_rmsnorm_rope_kernel[(n_rows, n_cols, 1)]( - input, - sin, - cos, - q_output, - k_output, - v_output, - q_weight, - q_bias, - k_weight, - k_bias, - batch_size, - q_hidden_size, - kv_hidden_size, - total_hidden_size, - eps, - Q_BLOCK_SIZE, - KV_BLOCK_SIZE, - BIAS, - head_dim, - head_dim // 2, - ) + input, + sin, + cos, + q_output, + k_output, + v_output, + q_weight, + q_bias, + k_weight, + k_bias, + batch_size, + q_hidden_size, + kv_hidden_size, + total_hidden_size, + eps, + Q_BLOCK_SIZE, + KV_BLOCK_SIZE, + BIAS, + head_dim, + head_dim // 2, + ) return q_output, k_output, v_output + def split_qkv_rmsnorm_rope_impl_fake( input: torch.Tensor, sin: torch.Tensor, @@ -303,6 +306,7 @@ def split_qkv_rmsnorm_rope_impl_fake( ) return q_output, k_output, v_output + direct_register_custom_op(op_name="qkv_rmsnorm_rope", op_func=split_qkv_rmsnorm_rope_impl, fake_impl=split_qkv_rmsnorm_rope_impl_fake, From b20db1d453b687509b6567292d9d7869b8183d50 Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Fri, 5 Dec 2025 07:37:03 +0000 Subject: [PATCH 07/40] move special operator to attention metadata builder Signed-off-by: wxsIcey <1790571317@qq.com> --- vllm_ascend/ascend_forward_context.py | 15 ++++++++++++++- vllm_ascend/ops/rotary_embedding.py | 16 +--------------- vllm_ascend/worker/model_runner_v1.py | 8 ++++++-- 3 files changed, 21 insertions(+), 18 deletions(-) diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index b4343e76f0c..d9f619c7723 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -44,6 +44,7 @@ def set_ascend_forward_context( prefetch_stream: torch.npu.Stream = None, model_instance: torch.nn.Module = None, weight_prefetch_method: Optional[WeightPrefetchMethod] = None, + cos_sin_cache: Optional[torch.Tensor] = None, is_mtp_model=False): """A context manager that stores the current forward context, can be attention metadata, etc. @@ -127,7 +128,19 @@ def set_ascend_forward_context( forward_context.model_instance = model_instance forward_context.weight_prefetch_method = weight_prefetch_method forward_context.is_mtp_model = is_mtp_model - + + # initialize rope + if cos_sin_cache is not None: + last_dim = cos_sin_cache.size()[-1] + cos, sin = cos_sin_cache.reshape(-1, 2, last_dim // 2).repeat( + 1, 1, 2).chunk(2, dim=-2) + # BSNH + forward_context.cos = cos.view(1, -1, 1, last_dim).contiguous() + forward_context.sin = sin.view(1, -1, 1, last_dim).contiguous() + else: + forward_context.cos = None + forward_context.sin = None + if num_tokens is None and attn_metadata is not None: num_tokens = attn_metadata.num_actual_tokens diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index ef398faef00..8133f933f79 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -75,7 +75,7 @@ def _rope_forward_oot( # Although this function modifies in-place, please retain the function's return value. # Otherwise, the graph fusion operation may fail. query, key = torch_npu.npu_apply_rotary_pos_emb( - query, key, self.cos, self.sin) + query, key, forward_context.cos, forward_context.sin) elif self.rotary_dim < self.head_size: num_tokens = query.shape[0] query = query.view(num_tokens, -1, self.head_size) @@ -141,20 +141,6 @@ def forward_oot( is_neox_style = self.is_neox_style if is_neox_style_override is not None: is_neox_style = is_neox_style_override - forward_context = get_forward_context() - is_first_layer = forward_context.is_first_layer - # Generate cos and sin outside layers to avoid repeated calculation. - if is_neox_style and self.head_size == 128 and self.cos_sin_cache.shape[ - -1] == 128: - if is_first_layer: - cos_sin = self.cos_sin_cache.index_select(0, positions) - last_dim = cos_sin.size()[-1] - cos, sin = cos_sin.reshape(-1, 2, last_dim // 2).repeat( - 1, 1, 2).chunk(2, dim=-2) - # BSNH - self.cos = cos.view(1, -1, 1, last_dim).contiguous() - self.sin = sin.view(1, -1, 1, last_dim).contiguous() - forward_context.is_first_layer = False return _rope_forward_oot(self, positions, query, key, is_neox_style, offsets) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index d4b4b25bf74..60918dcb7e8 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1444,7 +1444,9 @@ def execute_model( total_num_scheduled_tokens, prefetch_stream=self.prefetch_stream, model_instance=self.model, - weight_prefetch_method=self.weight_prefetch_method): + weight_prefetch_method=self.weight_prefetch_method, + cos_sin_cache=self.model.model.layers[ + self.model.model.start_layer].self_attn.rotary_emb.cos_sin_cache.index_select(0, positions)): self.maybe_setup_kv_connector(scheduler_output) hidden_states = self._generate_process_reqs_hidden_states( @@ -2159,7 +2161,9 @@ def dummy_drafter_compute_logits(hidden_states): batch_descriptor=batch_descriptor, prefetch_stream=self.prefetch_stream, model_instance=self.model, - weight_prefetch_method=self.weight_prefetch_method): + weight_prefetch_method=self.weight_prefetch_method, + cos_sin_cache=self.model.model.layers[ + self.model.model.start_layer].self_attn.rotary_emb.cos_sin_cache.index_select(0, positions)): hidden_states = self._generate_dummy_run_hidden_states( input_ids, positions, num_tokens_padded, intermediate_tensors, inputs_embeds) From 73766176c84b5d45e81045d6fea09592bdb8ce33 Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Sun, 7 Dec 2025 11:57:24 +0800 Subject: [PATCH 08/40] add e2e test Signed-off-by: wxsIcey <1790571317@qq.com> --- .../compile/test_qknorm_rope_fusion.py | 247 ++++++++++++++++++ 1 file changed, 247 insertions(+) create mode 100644 tests/e2e/singlecard/compile/test_qknorm_rope_fusion.py diff --git a/tests/e2e/singlecard/compile/test_qknorm_rope_fusion.py b/tests/e2e/singlecard/compile/test_qknorm_rope_fusion.py new file mode 100644 index 00000000000..0d01c9e50af --- /dev/null +++ b/tests/e2e/singlecard/compile/test_qknorm_rope_fusion.py @@ -0,0 +1,247 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import List + +# import pytest +import torch +import torch.nn as nn +import torch_npu +import vllm.config +from vllm.compilation.fx_utils import OpOverload +from vllm.config import ModelConfig, VllmConfig +from tests.e2e.singlecard.compile.backend import TestBackend +from vllm_ascend.compilation.passes.qknorm_rope_fusion_pass import \ + QKNormRopeFusionPass + + +class TestQKNormRopeModelNoBias(nn.Module): + """ + A minimal test model that simulates the pattern: + QKV split → Q RMSNorm → K RMSNorm → Reshape → RoPE (no bias) + """ + + def __init__(self, head_dim: int, num_heads: int, num_kv_heads: int, + eps: float = 1e-6, device: str = "npu"): + super().__init__() + self.head_dim = head_dim + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.eps = eps + + # Calculate sizes + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + # Parameters + self.q_weight = nn.Parameter(torch.randn(head_dim, device=device)) + self.k_weight = nn.Parameter(torch.randn(head_dim, device=device)) + + # RoPE parameters + self.cos_weight = nn.Parameter(torch.randn(1, 1, 1, head_dim, device=device)) + self.sin_weight = nn.Parameter(torch.randn(1, 1, 1, head_dim, device=device)) + + self.seq_len = None # To be set during forward pass + + def forward(self, qkv): + """ + Forward pass simulating the unfused pattern (no bias) + """ + seq_len = qkv.shape[0] + cos = self.cos_weight.expand(1, seq_len, 1, self.head_dim) + sin = self.sin_weight.expand(1, seq_len, 1, self.head_dim) + + # Split QKV + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) + q_norm_out, _ = torch_npu.npu_rms_norm(q_by_head, self.q_weight, self.eps) + q_flat = q_norm_out.view(q.shape) + q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim) + + k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim) + k_norm_out, _ = torch_npu.npu_rms_norm(k_by_head, self.k_weight, self.eps) + k_flat = k_norm_out.view(k.shape) + k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim) + + # Apply RoPE + q_rope, k_rope = torch_npu.npu_apply_rotary_pos_emb( + q_reshape, k_reshape, cos, sin + ) + + return q_rope, k_rope, v + + def ops_in_model_before(self) -> List[OpOverload]: + """Return the list of expected operators BEFORE fusion.""" + return [ + torch.ops.npu.npu_apply_rotary_pos_emb.default + ] + + def ops_in_model_after(self) -> List[OpOverload]: + """Return the list of expected operators AFTER successful fusion.""" + return [torch.ops.vllm.qkv_rmsnorm_rope.default] + + +class TestQKNormRopeModelWithBias(nn.Module): + """ + A minimal test model that simulates the pattern: + QKV split → Q RMSNorm → Q Bias → K RMSNorm → K Bias → Reshape → RoPE (with bias) + """ + + def __init__(self, head_dim: int, num_heads: int, num_kv_heads: int, + eps: float = 1e-6, device: str = "npu"): + super().__init__() + self.head_dim = head_dim + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.eps = eps + + # Calculate sizes + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + # Parameters + self.q_weight = nn.Parameter(torch.randn(head_dim, device=device)) + self.k_weight = nn.Parameter(torch.randn(head_dim, device=device)) + self.q_bias = nn.Parameter(torch.randn(head_dim, device=device)) + self.k_bias = nn.Parameter(torch.randn(head_dim, device=device)) + + # RoPE parameters + self.cos_weight = nn.Parameter(torch.randn(1, 1, 1, head_dim, device=device)) + self.sin_weight = nn.Parameter(torch.randn(1, 1, 1, head_dim, device=device)) + + def forward(self, qkv): + """ + Forward pass simulating the unfused pattern (with bias) + """ + seq_len = qkv.shape[0] + cos = self.cos_weight.expand(1, seq_len, 1, self.head_dim) + sin = self.sin_weight.expand(1, seq_len, 1, self.head_dim) + + # Split QKV + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) + q_norm_out, _ = torch_npu.npu_rms_norm(q_by_head, self.q_weight, self.eps) + q_normed = q_norm_out + self.q_bias + q_flat = q_normed.view(q.shape) + q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim) + + k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim) + k_norm_out, _ = torch_npu.npu_rms_norm(k_by_head, self.k_weight, self.eps) + k_normed = k_norm_out + self.k_bias + k_flat = k_normed.view(k.shape) + k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim) + + # Apply RoPE + q_rope, k_rope = torch_npu.npu_apply_rotary_pos_emb( + q_reshape, k_reshape, cos, sin + ) + + return q_rope, k_rope, v + + def ops_in_model_before(self) -> List[OpOverload]: + """Return the list of expected operators BEFORE fusion.""" + return [ + torch.ops.npu.npu_apply_rotary_pos_emb.default, + torch.ops.aten.add.Tensor + ] + + def ops_in_model_after(self) -> List[OpOverload]: + """Return the list of expected operators AFTER successful fusion.""" + return [torch.ops.vllm.qkv_rmsnorm_rope.default] + + +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +# @pytest.mark.parametrize("seq_len", [1, 128, 257]) +# @pytest.mark.parametrize("eps", [1e-5, 1e-6]) +# @pytest.mark.parametrize("with_bias", [False, True]) +def test_qknorm_rope_fusion( + dtype: torch.dtype, + seq_len: int, + eps: float, + with_bias: bool, +): + """ + End-to-end test for QKV split + RMSNorm + RoPE fusion. + Tests both with and without bias versions. + """ + torch.set_default_dtype(dtype) + torch.manual_seed(42) + + # Model parameters + head_dim = 128 + num_heads = 32 + num_kv_heads = 8 + + vllm_config = VllmConfig(model_config=ModelConfig(dtype=dtype)) + + with vllm.config.set_current_vllm_config(vllm_config): + # Create backend with the fusion pass + backend = TestBackend(custom_passes=[QKNormRopeFusionPass(vllm_config)]) + + # Create appropriate test model based on bias flag + if with_bias: + model = TestQKNormRopeModelWithBias( + head_dim=head_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + eps=eps, + device="npu" + ) + else: + model = TestQKNormRopeModelNoBias( + head_dim=head_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + eps=eps, + device="npu" + ) + + model = model.to("npu") + + # Create test input + qkv_size = num_heads * head_dim + 2 * num_kv_heads * head_dim + x = torch.rand( + seq_len, + qkv_size, + device="npu", + dtype=dtype, + requires_grad=False + ) + + # Run unfused model + result_unfused = model(x) + print(f"Unfused result shapes: {[t.shape for t in result_unfused]}") + + # Compile with fusion + model_fused = torch.compile(model, backend=backend) + result_fused = model_fused(x) + print(f"Fused result shapes: {[t.shape for t in result_fused]}") + + print("=== Checking operator fusion ===") + backend.check_before_ops(model.ops_in_model_before()) + backend.check_after_ops(model.ops_in_model_after()) + + +if __name__ == "__main__": + + test_qknorm_rope_fusion( + dtype=torch.float16, + seq_len=128, + eps=1e-6, + with_bias=False, + ) From 65ce080310870501bcdb45b5182d2da8d6909866 Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Mon, 8 Dec 2025 08:21:54 +0000 Subject: [PATCH 09/40] remove first layer change Signed-off-by: wxsIcey <1790571317@qq.com> --- vllm_ascend/ascend_forward_context.py | 13 ------------- vllm_ascend/ops/rotary_embedding.py | 16 +++++++++++++++- vllm_ascend/worker/model_runner_v1.py | 10 ++-------- 3 files changed, 17 insertions(+), 22 deletions(-) diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index d9f619c7723..27c65c27d63 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -44,7 +44,6 @@ def set_ascend_forward_context( prefetch_stream: torch.npu.Stream = None, model_instance: torch.nn.Module = None, weight_prefetch_method: Optional[WeightPrefetchMethod] = None, - cos_sin_cache: Optional[torch.Tensor] = None, is_mtp_model=False): """A context manager that stores the current forward context, can be attention metadata, etc. @@ -128,18 +127,6 @@ def set_ascend_forward_context( forward_context.model_instance = model_instance forward_context.weight_prefetch_method = weight_prefetch_method forward_context.is_mtp_model = is_mtp_model - - # initialize rope - if cos_sin_cache is not None: - last_dim = cos_sin_cache.size()[-1] - cos, sin = cos_sin_cache.reshape(-1, 2, last_dim // 2).repeat( - 1, 1, 2).chunk(2, dim=-2) - # BSNH - forward_context.cos = cos.view(1, -1, 1, last_dim).contiguous() - forward_context.sin = sin.view(1, -1, 1, last_dim).contiguous() - else: - forward_context.cos = None - forward_context.sin = None if num_tokens is None and attn_metadata is not None: num_tokens = attn_metadata.num_actual_tokens diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index 8133f933f79..ef398faef00 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -75,7 +75,7 @@ def _rope_forward_oot( # Although this function modifies in-place, please retain the function's return value. # Otherwise, the graph fusion operation may fail. query, key = torch_npu.npu_apply_rotary_pos_emb( - query, key, forward_context.cos, forward_context.sin) + query, key, self.cos, self.sin) elif self.rotary_dim < self.head_size: num_tokens = query.shape[0] query = query.view(num_tokens, -1, self.head_size) @@ -141,6 +141,20 @@ def forward_oot( is_neox_style = self.is_neox_style if is_neox_style_override is not None: is_neox_style = is_neox_style_override + forward_context = get_forward_context() + is_first_layer = forward_context.is_first_layer + # Generate cos and sin outside layers to avoid repeated calculation. + if is_neox_style and self.head_size == 128 and self.cos_sin_cache.shape[ + -1] == 128: + if is_first_layer: + cos_sin = self.cos_sin_cache.index_select(0, positions) + last_dim = cos_sin.size()[-1] + cos, sin = cos_sin.reshape(-1, 2, last_dim // 2).repeat( + 1, 1, 2).chunk(2, dim=-2) + # BSNH + self.cos = cos.view(1, -1, 1, last_dim).contiguous() + self.sin = sin.view(1, -1, 1, last_dim).contiguous() + forward_context.is_first_layer = False return _rope_forward_oot(self, positions, query, key, is_neox_style, offsets) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 60918dcb7e8..b294292c5ed 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1443,10 +1443,7 @@ def execute_model( num_actual_tokens=scheduler_output. total_num_scheduled_tokens, prefetch_stream=self.prefetch_stream, - model_instance=self.model, - weight_prefetch_method=self.weight_prefetch_method, - cos_sin_cache=self.model.model.layers[ - self.model.model.start_layer].self_attn.rotary_emb.cos_sin_cache.index_select(0, positions)): + model_instance=self.model): self.maybe_setup_kv_connector(scheduler_output) hidden_states = self._generate_process_reqs_hidden_states( @@ -2160,10 +2157,7 @@ def dummy_drafter_compute_logits(hidden_states): aclgraph_runtime_mode=aclgraph_runtime_mode, batch_descriptor=batch_descriptor, prefetch_stream=self.prefetch_stream, - model_instance=self.model, - weight_prefetch_method=self.weight_prefetch_method, - cos_sin_cache=self.model.model.layers[ - self.model.model.start_layer].self_attn.rotary_emb.cos_sin_cache.index_select(0, positions)): + model_instance=self.model): hidden_states = self._generate_dummy_run_hidden_states( input_ids, positions, num_tokens_padded, intermediate_tensors, inputs_embeds) From 691d54c433374de8533278b07a9d853a390f711a Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Mon, 8 Dec 2025 08:35:22 +0000 Subject: [PATCH 10/40] tiny fix Signed-off-by: wxsIcey <1790571317@qq.com> --- vllm_ascend/ascend_forward_context.py | 8 +++++++- vllm_ascend/ops/rotary_embedding.py | 16 +--------------- vllm_ascend/worker/model_runner_v1.py | 20 ++++++++++++++++++-- 3 files changed, 26 insertions(+), 18 deletions(-) diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index 27c65c27d63..8faa01fa9c5 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -44,6 +44,8 @@ def set_ascend_forward_context( prefetch_stream: torch.npu.Stream = None, model_instance: torch.nn.Module = None, weight_prefetch_method: Optional[WeightPrefetchMethod] = None, + cos: Optional[torch.Tensor] = None, + sin: Optional[torch.Tensor] = None, is_mtp_model=False): """A context manager that stores the current forward context, can be attention metadata, etc. @@ -127,7 +129,11 @@ def set_ascend_forward_context( forward_context.model_instance = model_instance forward_context.weight_prefetch_method = weight_prefetch_method forward_context.is_mtp_model = is_mtp_model - + + # initialize rope + forward_context.cos = cos + forward_context.sin = sin + if num_tokens is None and attn_metadata is not None: num_tokens = attn_metadata.num_actual_tokens diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index ef398faef00..8133f933f79 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -75,7 +75,7 @@ def _rope_forward_oot( # Although this function modifies in-place, please retain the function's return value. # Otherwise, the graph fusion operation may fail. query, key = torch_npu.npu_apply_rotary_pos_emb( - query, key, self.cos, self.sin) + query, key, forward_context.cos, forward_context.sin) elif self.rotary_dim < self.head_size: num_tokens = query.shape[0] query = query.view(num_tokens, -1, self.head_size) @@ -141,20 +141,6 @@ def forward_oot( is_neox_style = self.is_neox_style if is_neox_style_override is not None: is_neox_style = is_neox_style_override - forward_context = get_forward_context() - is_first_layer = forward_context.is_first_layer - # Generate cos and sin outside layers to avoid repeated calculation. - if is_neox_style and self.head_size == 128 and self.cos_sin_cache.shape[ - -1] == 128: - if is_first_layer: - cos_sin = self.cos_sin_cache.index_select(0, positions) - last_dim = cos_sin.size()[-1] - cos, sin = cos_sin.reshape(-1, 2, last_dim // 2).repeat( - 1, 1, 2).chunk(2, dim=-2) - # BSNH - self.cos = cos.view(1, -1, 1, last_dim).contiguous() - self.sin = sin.view(1, -1, 1, last_dim).contiguous() - forward_context.is_first_layer = False return _rope_forward_oot(self, positions, query, key, is_neox_style, offsets) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index b294292c5ed..ef36dc4d242 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1443,7 +1443,10 @@ def execute_model( num_actual_tokens=scheduler_output. total_num_scheduled_tokens, prefetch_stream=self.prefetch_stream, - model_instance=self.model): + model_instance=self.model, + weight_prefetch_method=self.weight_prefetch_method, + cos=self.cos[:, :maybe_padded_num_tokens], + sin=self.sin[:, :maybe_padded_num_tokens]): self.maybe_setup_kv_connector(scheduler_output) hidden_states = self._generate_process_reqs_hidden_states( @@ -2145,6 +2148,16 @@ def dummy_drafter_compute_logits(hidden_states): self.drafter.model, "compute_logits"): return self.drafter.model.compute_logits( hidden_states[dummy_indices]) + + # initialize rope + cos_sin_cache=self.model.model.layers[ + self.model.model.start_layer].self_attn.rotary_emb.cos_sin_cache.index_select(0, positions) + last_dim = cos_sin_cache.size()[-1] + cos, sin = cos_sin_cache.reshape(-1, 2, last_dim // 2).repeat( + 1, 1, 2).chunk(2, dim=-2) + # BSNH + self.cos[:, :maybe_padded_num_tokens] = cos.view(1, -1, 1, last_dim).contiguous() + self.sin[:, :maybe_padded_num_tokens] = sin.view(1, -1, 1, last_dim).contiguous() with set_ascend_forward_context( attn_metadata, @@ -2157,7 +2170,10 @@ def dummy_drafter_compute_logits(hidden_states): aclgraph_runtime_mode=aclgraph_runtime_mode, batch_descriptor=batch_descriptor, prefetch_stream=self.prefetch_stream, - model_instance=self.model): + model_instance=self.model, + weight_prefetch_method=self.weight_prefetch_metho, + cos=self.cos[:, :num_tokens], + sin=self.sin[:, :num_tokens]): hidden_states = self._generate_dummy_run_hidden_states( input_ids, positions, num_tokens_padded, intermediate_tensors, inputs_embeds) From 6e952395260d234a29874054ec5d9a2fde9bc62c Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Mon, 8 Dec 2025 08:55:08 +0000 Subject: [PATCH 11/40] fix Signed-off-by: wxsIcey <1790571317@qq.com> --- vllm_ascend/worker/model_runner_v1.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index ef36dc4d242..d27fc3347de 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -2156,8 +2156,8 @@ def dummy_drafter_compute_logits(hidden_states): cos, sin = cos_sin_cache.reshape(-1, 2, last_dim // 2).repeat( 1, 1, 2).chunk(2, dim=-2) # BSNH - self.cos[:, :maybe_padded_num_tokens] = cos.view(1, -1, 1, last_dim).contiguous() - self.sin[:, :maybe_padded_num_tokens] = sin.view(1, -1, 1, last_dim).contiguous() + self.cos[:, :num_tokens] = cos.view(1, -1, 1, last_dim).contiguous() + self.sin[:, :num_tokens] = sin.view(1, -1, 1, last_dim).contiguous() with set_ascend_forward_context( attn_metadata, @@ -2171,7 +2171,7 @@ def dummy_drafter_compute_logits(hidden_states): batch_descriptor=batch_descriptor, prefetch_stream=self.prefetch_stream, model_instance=self.model, - weight_prefetch_method=self.weight_prefetch_metho, + weight_prefetch_method=self.weight_prefetch_method, cos=self.cos[:, :num_tokens], sin=self.sin[:, :num_tokens]): hidden_states = self._generate_dummy_run_hidden_states( From fef462ffbf9e4f20baf867f8df3510039c8ee054 Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Mon, 8 Dec 2025 09:18:55 +0000 Subject: [PATCH 12/40] fix Signed-off-by: wxsIcey <1790571317@qq.com> --- vllm_ascend/worker/model_runner_v1.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index d27fc3347de..8baff88d5f5 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1445,8 +1445,8 @@ def execute_model( prefetch_stream=self.prefetch_stream, model_instance=self.model, weight_prefetch_method=self.weight_prefetch_method, - cos=self.cos[:, :maybe_padded_num_tokens], - sin=self.sin[:, :maybe_padded_num_tokens]): + cos=self.cos[:, :maybe_padded_num_tokens] if self.cos is not None else None, + sin=self.sin[:, :maybe_padded_num_tokens] if self.sin is not None else None): self.maybe_setup_kv_connector(scheduler_output) hidden_states = self._generate_process_reqs_hidden_states( @@ -2172,8 +2172,8 @@ def dummy_drafter_compute_logits(hidden_states): prefetch_stream=self.prefetch_stream, model_instance=self.model, weight_prefetch_method=self.weight_prefetch_method, - cos=self.cos[:, :num_tokens], - sin=self.sin[:, :num_tokens]): + cos=self.cos[:, :num_tokens] if self.cos is not None else None, + sin=self.sin[:, :num_tokens] if self.sin is not None else None): hidden_states = self._generate_dummy_run_hidden_states( input_ids, positions, num_tokens_padded, intermediate_tensors, inputs_embeds) From 6f72f7fdbc2345d0b6720df855e39c7315c25ff1 Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Mon, 8 Dec 2025 10:13:47 +0000 Subject: [PATCH 13/40] fix Signed-off-by: wxsIcey <1790571317@qq.com> --- .../compile/test_qknorm_rope_fusion.py | 27 +++++++------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/tests/e2e/singlecard/compile/test_qknorm_rope_fusion.py b/tests/e2e/singlecard/compile/test_qknorm_rope_fusion.py index 0d01c9e50af..7647ba814a9 100644 --- a/tests/e2e/singlecard/compile/test_qknorm_rope_fusion.py +++ b/tests/e2e/singlecard/compile/test_qknorm_rope_fusion.py @@ -16,7 +16,7 @@ # from typing import List -# import pytest +import pytest import torch import torch.nn as nn import torch_npu @@ -165,10 +165,10 @@ def ops_in_model_after(self) -> List[OpOverload]: return [torch.ops.vllm.qkv_rmsnorm_rope.default] -# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -# @pytest.mark.parametrize("seq_len", [1, 128, 257]) -# @pytest.mark.parametrize("eps", [1e-5, 1e-6]) -# @pytest.mark.parametrize("with_bias", [False, True]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("seq_len", [1, 128]) +@pytest.mark.parametrize("eps", [1e-5, 1e-6]) +@pytest.mark.parametrize("with_bias", [False]) def test_qknorm_rope_fusion( dtype: torch.dtype, seq_len: int, @@ -228,20 +228,11 @@ def test_qknorm_rope_fusion( print(f"Unfused result shapes: {[t.shape for t in result_unfused]}") # Compile with fusion - model_fused = torch.compile(model, backend=backend) - result_fused = model_fused(x) + with torch.no_grad(): + model_fused = torch.compile(model, backend=backend) + result_fused = model_fused(x) print(f"Fused result shapes: {[t.shape for t in result_fused]}") print("=== Checking operator fusion ===") backend.check_before_ops(model.ops_in_model_before()) - backend.check_after_ops(model.ops_in_model_after()) - - -if __name__ == "__main__": - - test_qknorm_rope_fusion( - dtype=torch.float16, - seq_len=128, - eps=1e-6, - with_bias=False, - ) + backend.check_after_ops(model.ops_in_model_after()) \ No newline at end of file From 171539a511e4995daeb80d52cfa97256dafa24ee Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Tue, 9 Dec 2025 06:20:17 +0000 Subject: [PATCH 14/40] remove e2e test Signed-off-by: wxsIcey <1790571317@qq.com> --- .../compile/test_qknorm_rope_fusion.py | 238 ------------------ 1 file changed, 238 deletions(-) delete mode 100644 tests/e2e/singlecard/compile/test_qknorm_rope_fusion.py diff --git a/tests/e2e/singlecard/compile/test_qknorm_rope_fusion.py b/tests/e2e/singlecard/compile/test_qknorm_rope_fusion.py deleted file mode 100644 index 7647ba814a9..00000000000 --- a/tests/e2e/singlecard/compile/test_qknorm_rope_fusion.py +++ /dev/null @@ -1,238 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2023 The vLLM team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from typing import List - -import pytest -import torch -import torch.nn as nn -import torch_npu -import vllm.config -from vllm.compilation.fx_utils import OpOverload -from vllm.config import ModelConfig, VllmConfig -from tests.e2e.singlecard.compile.backend import TestBackend -from vllm_ascend.compilation.passes.qknorm_rope_fusion_pass import \ - QKNormRopeFusionPass - - -class TestQKNormRopeModelNoBias(nn.Module): - """ - A minimal test model that simulates the pattern: - QKV split → Q RMSNorm → K RMSNorm → Reshape → RoPE (no bias) - """ - - def __init__(self, head_dim: int, num_heads: int, num_kv_heads: int, - eps: float = 1e-6, device: str = "npu"): - super().__init__() - self.head_dim = head_dim - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.eps = eps - - # Calculate sizes - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - - # Parameters - self.q_weight = nn.Parameter(torch.randn(head_dim, device=device)) - self.k_weight = nn.Parameter(torch.randn(head_dim, device=device)) - - # RoPE parameters - self.cos_weight = nn.Parameter(torch.randn(1, 1, 1, head_dim, device=device)) - self.sin_weight = nn.Parameter(torch.randn(1, 1, 1, head_dim, device=device)) - - self.seq_len = None # To be set during forward pass - - def forward(self, qkv): - """ - Forward pass simulating the unfused pattern (no bias) - """ - seq_len = qkv.shape[0] - cos = self.cos_weight.expand(1, seq_len, 1, self.head_dim) - sin = self.sin_weight.expand(1, seq_len, 1, self.head_dim) - - # Split QKV - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - - q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) - q_norm_out, _ = torch_npu.npu_rms_norm(q_by_head, self.q_weight, self.eps) - q_flat = q_norm_out.view(q.shape) - q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim) - - k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim) - k_norm_out, _ = torch_npu.npu_rms_norm(k_by_head, self.k_weight, self.eps) - k_flat = k_norm_out.view(k.shape) - k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim) - - # Apply RoPE - q_rope, k_rope = torch_npu.npu_apply_rotary_pos_emb( - q_reshape, k_reshape, cos, sin - ) - - return q_rope, k_rope, v - - def ops_in_model_before(self) -> List[OpOverload]: - """Return the list of expected operators BEFORE fusion.""" - return [ - torch.ops.npu.npu_apply_rotary_pos_emb.default - ] - - def ops_in_model_after(self) -> List[OpOverload]: - """Return the list of expected operators AFTER successful fusion.""" - return [torch.ops.vllm.qkv_rmsnorm_rope.default] - - -class TestQKNormRopeModelWithBias(nn.Module): - """ - A minimal test model that simulates the pattern: - QKV split → Q RMSNorm → Q Bias → K RMSNorm → K Bias → Reshape → RoPE (with bias) - """ - - def __init__(self, head_dim: int, num_heads: int, num_kv_heads: int, - eps: float = 1e-6, device: str = "npu"): - super().__init__() - self.head_dim = head_dim - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.eps = eps - - # Calculate sizes - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - - # Parameters - self.q_weight = nn.Parameter(torch.randn(head_dim, device=device)) - self.k_weight = nn.Parameter(torch.randn(head_dim, device=device)) - self.q_bias = nn.Parameter(torch.randn(head_dim, device=device)) - self.k_bias = nn.Parameter(torch.randn(head_dim, device=device)) - - # RoPE parameters - self.cos_weight = nn.Parameter(torch.randn(1, 1, 1, head_dim, device=device)) - self.sin_weight = nn.Parameter(torch.randn(1, 1, 1, head_dim, device=device)) - - def forward(self, qkv): - """ - Forward pass simulating the unfused pattern (with bias) - """ - seq_len = qkv.shape[0] - cos = self.cos_weight.expand(1, seq_len, 1, self.head_dim) - sin = self.sin_weight.expand(1, seq_len, 1, self.head_dim) - - # Split QKV - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - - q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) - q_norm_out, _ = torch_npu.npu_rms_norm(q_by_head, self.q_weight, self.eps) - q_normed = q_norm_out + self.q_bias - q_flat = q_normed.view(q.shape) - q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, self.head_dim) - - k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim) - k_norm_out, _ = torch_npu.npu_rms_norm(k_by_head, self.k_weight, self.eps) - k_normed = k_norm_out + self.k_bias - k_flat = k_normed.view(k.shape) - k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, self.head_dim) - - # Apply RoPE - q_rope, k_rope = torch_npu.npu_apply_rotary_pos_emb( - q_reshape, k_reshape, cos, sin - ) - - return q_rope, k_rope, v - - def ops_in_model_before(self) -> List[OpOverload]: - """Return the list of expected operators BEFORE fusion.""" - return [ - torch.ops.npu.npu_apply_rotary_pos_emb.default, - torch.ops.aten.add.Tensor - ] - - def ops_in_model_after(self) -> List[OpOverload]: - """Return the list of expected operators AFTER successful fusion.""" - return [torch.ops.vllm.qkv_rmsnorm_rope.default] - - -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("seq_len", [1, 128]) -@pytest.mark.parametrize("eps", [1e-5, 1e-6]) -@pytest.mark.parametrize("with_bias", [False]) -def test_qknorm_rope_fusion( - dtype: torch.dtype, - seq_len: int, - eps: float, - with_bias: bool, -): - """ - End-to-end test for QKV split + RMSNorm + RoPE fusion. - Tests both with and without bias versions. - """ - torch.set_default_dtype(dtype) - torch.manual_seed(42) - - # Model parameters - head_dim = 128 - num_heads = 32 - num_kv_heads = 8 - - vllm_config = VllmConfig(model_config=ModelConfig(dtype=dtype)) - - with vllm.config.set_current_vllm_config(vllm_config): - # Create backend with the fusion pass - backend = TestBackend(custom_passes=[QKNormRopeFusionPass(vllm_config)]) - - # Create appropriate test model based on bias flag - if with_bias: - model = TestQKNormRopeModelWithBias( - head_dim=head_dim, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - eps=eps, - device="npu" - ) - else: - model = TestQKNormRopeModelNoBias( - head_dim=head_dim, - num_heads=num_heads, - num_kv_heads=num_kv_heads, - eps=eps, - device="npu" - ) - - model = model.to("npu") - - # Create test input - qkv_size = num_heads * head_dim + 2 * num_kv_heads * head_dim - x = torch.rand( - seq_len, - qkv_size, - device="npu", - dtype=dtype, - requires_grad=False - ) - - # Run unfused model - result_unfused = model(x) - print(f"Unfused result shapes: {[t.shape for t in result_unfused]}") - - # Compile with fusion - with torch.no_grad(): - model_fused = torch.compile(model, backend=backend) - result_fused = model_fused(x) - print(f"Fused result shapes: {[t.shape for t in result_fused]}") - - print("=== Checking operator fusion ===") - backend.check_before_ops(model.ops_in_model_before()) - backend.check_after_ops(model.ops_in_model_after()) \ No newline at end of file From 4a745bd0f93e5ef18444c8a98f5a9a8c378cdb98 Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Tue, 9 Dec 2025 06:45:57 +0000 Subject: [PATCH 15/40] tiny fix Signed-off-by: wxsIcey <1790571317@qq.com> --- .../linearnorm/split_qkv_rmsnorm_rope.py | 14 +++------- vllm_ascend/worker/model_runner_v1.py | 26 ++++++++++++------- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py b/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py index 9d51e0a1ea2..e8130e85ed2 100644 --- a/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py +++ b/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py @@ -22,6 +22,8 @@ import triton.runtime.driver as driver from vllm.utils.torch_utils import direct_register_custom_op +from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num + @triton.jit def split_qkv_rmsnorm_rope_kernel( @@ -198,17 +200,6 @@ def split_qkv_rmsnorm_rope_kernel( output_offset += output_offset_step -kernels = {} - - -def get_npu_properties(): - device = torch.npu.current_device() - return driver.active.utils.get_device_properties(device) - - -num_vectorcore = get_npu_properties()["num_vectorcore"] - - def split_qkv_rmsnorm_rope_impl( input: torch.Tensor, sin: torch.Tensor, @@ -241,6 +232,7 @@ def split_qkv_rmsnorm_rope_impl( device=input.device, dtype=input.dtype) n_cols = kv_hidden_size // KV_BLOCK_SIZE + num_vectorcore = get_vectorcore_num() assert num_vectorcore % n_cols == 0 n_rows = num_vectorcore // n_cols BIAS = q_bias is not None diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 8baff88d5f5..e48252f29ed 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1445,8 +1445,10 @@ def execute_model( prefetch_stream=self.prefetch_stream, model_instance=self.model, weight_prefetch_method=self.weight_prefetch_method, - cos=self.cos[:, :maybe_padded_num_tokens] if self.cos is not None else None, - sin=self.sin[:, :maybe_padded_num_tokens] if self.sin is not None else None): + cos=self.cos[:, :maybe_padded_num_tokens] + if self.cos is not None else None, + sin=self.sin[:, :maybe_padded_num_tokens] + if self.sin is not None else None): self.maybe_setup_kv_connector(scheduler_output) hidden_states = self._generate_process_reqs_hidden_states( @@ -2148,16 +2150,20 @@ def dummy_drafter_compute_logits(hidden_states): self.drafter.model, "compute_logits"): return self.drafter.model.compute_logits( hidden_states[dummy_indices]) - + # initialize rope - cos_sin_cache=self.model.model.layers[ - self.model.model.start_layer].self_attn.rotary_emb.cos_sin_cache.index_select(0, positions) + cos_sin_cache = self.model.model.layers[ + self.model.model. + start_layer].self_attn.rotary_emb.cos_sin_cache.index_select( + 0, positions) last_dim = cos_sin_cache.size()[-1] cos, sin = cos_sin_cache.reshape(-1, 2, last_dim // 2).repeat( 1, 1, 2).chunk(2, dim=-2) # BSNH - self.cos[:, :num_tokens] = cos.view(1, -1, 1, last_dim).contiguous() - self.sin[:, :num_tokens] = sin.view(1, -1, 1, last_dim).contiguous() + self.cos[:, :num_tokens] = cos.view(1, -1, 1, + last_dim).contiguous() + self.sin[:, :num_tokens] = sin.view(1, -1, 1, + last_dim).contiguous() with set_ascend_forward_context( attn_metadata, @@ -2172,8 +2178,10 @@ def dummy_drafter_compute_logits(hidden_states): prefetch_stream=self.prefetch_stream, model_instance=self.model, weight_prefetch_method=self.weight_prefetch_method, - cos=self.cos[:, :num_tokens] if self.cos is not None else None, - sin=self.sin[:, :num_tokens] if self.sin is not None else None): + cos=self.cos[:, :num_tokens] + if self.cos is not None else None, + sin=self.sin[:, :num_tokens] + if self.sin is not None else None): hidden_states = self._generate_dummy_run_hidden_states( input_ids, positions, num_tokens_padded, intermediate_tensors, inputs_embeds) From 7f5ab088a9924a01c388fcac13543b036034bce4 Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Tue, 9 Dec 2025 09:39:46 +0000 Subject: [PATCH 16/40] fix Signed-off-by: wxsIcey <1790571317@qq.com> --- vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py b/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py index e8130e85ed2..5a0d69f2748 100644 --- a/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py +++ b/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py @@ -18,7 +18,7 @@ import torch import triton -import triton.language as tl +import triton.language as tl # type: ignore import triton.runtime.driver as driver from vllm.utils.torch_utils import direct_register_custom_op From a9cfb33e0dc8d365acac523468c8074f81c591c8 Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Thu, 11 Dec 2025 02:52:41 +0000 Subject: [PATCH 17/40] fix eagle spec decode Signed-off-by: wxsIcey <1790571317@qq.com> --- vllm_ascend/ops/rotary_embedding.py | 76 +++++++++++++++++-- .../linearnorm/split_qkv_rmsnorm_rope.py | 2 +- vllm_ascend/spec_decode/eagle_proposer.py | 16 +++- vllm_ascend/worker/model_runner_v1.py | 4 + 4 files changed, 87 insertions(+), 11 deletions(-) diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index 8133f933f79..3a81647dc32 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -29,6 +29,69 @@ from vllm_ascend.utils import (AscendDeviceType, enable_custom_op, get_ascend_device_type) +# Currently, rope ops used on npu requires detached cos && sin as inputs. +# However, RotaryEmbedding in vllm use cos_sin_cache as a whole variable. +# So we have to preprocess cos_sin_cache int cos && sin. In the future, +# we shall implement a new rope ops which accept cos_sin_cache as inputs. +_cos_sin_cache: Optional[torch.Tensor] = None +_cos_cache: Optional[torch.Tensor] = None +_sin_cache: Optional[torch.Tensor] = None +_cos: Optional[torch.Tensor] = None +_sin: Optional[torch.Tensor] = None + + +def _record_cos_sin_cache(cos_sin_cache): + global _cos_sin_cache + if _cos_sin_cache is not None: + return + _cos_sin_cache = cos_sin_cache + + +def initialize_cos_sin(vllm_config, dtype, device): + global _cos_cache + global _sin_cache + + head_dim = vllm_config.model_config.get_head_size() + max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens + _cos_cache = torch.ones(1, + max_num_batched_tokens, + 1, + head_dim, + dtype=dtype, + device=device) + _sin_cache = torch.zeros(1, + max_num_batched_tokens, + 1, + head_dim, + dtype=dtype, + device=device) + + +def update_cos_sin(positions): + global _cos_cache + global _sin_cache + global _cos + global _sin + + if _cos_sin_cache is None or \ + _cos_cache is None or \ + _sin_cache is None: + return + + num_tokens = positions.size(0) + _cos_cache[:, :num_tokens] = _cos_sin_cache.index_select( + 0, positions).view(num_tokens, 2, -1).repeat(1, 1, 2).chunk(2, + dim=-2)[0] + _sin_cache[:, :num_tokens] = _cos_sin_cache.index_select( + 0, positions).view(num_tokens, 2, -1).repeat(1, 1, 2).chunk(2, + dim=-2)[1] + _cos = _cos_cache[:, :num_tokens] + _sin = _sin_cache[:, :num_tokens] + + +def get_cos_sin(): + return _cos, _sin + def _custom_rotary_embedding_enabled(query, neox_style, head_size): return query.dtype == torch.float16 and neox_style and head_size % 32 == 0 and enable_custom_op( @@ -65,8 +128,9 @@ def _rope_forward_oot( raise NotImplementedError( "Batched rotary embedding is currently not supported on NPU.") else: - if hasattr(self, "cos") and hasattr(self, "sin") and \ - self.cos is not None and self.sin is not None: + cos, sin = get_cos_sin() + if is_neox_style and self.head_size == 128 and self.cos_sin_cache.shape[ + -1] == 128 and cos is not None and sin is not None: # If cos and sin are generated outside, use npu_apply_rotary_pos_emb to avoid redundant calculation. # This method requires head_size and rotary_dim equal 128 and neox_style is True query = query.contiguous().view(1, query.shape[0], -1, @@ -75,7 +139,7 @@ def _rope_forward_oot( # Although this function modifies in-place, please retain the function's return value. # Otherwise, the graph fusion operation may fail. query, key = torch_npu.npu_apply_rotary_pos_emb( - query, key, forward_context.cos, forward_context.sin) + query, key, cos, sin) elif self.rotary_dim < self.head_size: num_tokens = query.shape[0] query = query.view(num_tokens, -1, self.head_size) @@ -125,10 +189,9 @@ def __init__( is_neox_style: bool, dtype: torch.dtype, ) -> None: - self.cos = None - self.sin = None super().__init__(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype) + _record_cos_sin_cache(self.cos_sin_cache) def forward_oot( self, @@ -162,8 +225,6 @@ def __init__( beta_fast: int = 32, beta_slow: int = 1, ) -> None: - self.cos = None - self.sin = None extra_kwargs = { "extrapolation_factor": extrapolation_factor, "attn_factor": attn_factor, @@ -172,6 +233,7 @@ def __init__( } super().__init__(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, scaling_factor, dtype, **extra_kwargs) + _record_cos_sin_cache(self.cos_sin_cache) def forward_oot( self, diff --git a/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py b/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py index 5a0d69f2748..fba0f31203d 100644 --- a/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py +++ b/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py @@ -18,7 +18,7 @@ import torch import triton -import triton.language as tl # type: ignore +import triton.language as tl # type: ignore import triton.runtime.driver as driver from vllm.utils.torch_utils import direct_register_custom_op diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index 24a846d9bae..2430c52f273 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -25,6 +25,7 @@ from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.utils import AscendCommonAttentionMetadata +from vllm_ascend.ops.rotary_embedding import update_cos_sin from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType PADDING_SLOT_ID = -1 @@ -149,7 +150,7 @@ def dummy_run(self, num_tokens=num_tokens): self.model( input_ids=self.input_ids[:num_tokens], - positions=self.positions[:num_tokens], + positions=positions, hidden_states=self.hidden_states[:num_tokens], ) dummy_compute_logits(self.hidden_states) @@ -340,12 +341,16 @@ def _propose( attn_metadata = builder.build(0, common_attn_metadata, self.runner.get_model()) + positions = self.positions[:num_input_tokens] + # update global cos, sin + update_cos_sin(positions) + with set_ascend_forward_context(attn_metadata, self.vllm_config, num_tokens=num_input_tokens): last_hidden_states, hidden_states = self.model( input_ids=self.input_ids[:num_input_tokens], - positions=self.positions[:num_input_tokens], + positions=positions, hidden_states=self.hidden_states[:num_input_tokens], ) sample_hidden_states = last_hidden_states[last_token_indices] @@ -444,13 +449,18 @@ def _propose( attn_metadata.attn_mask = attn_mask # Run the model. + + positions = self.positions[:input_batch_size] + # update global cos, sin + update_cos_sin(positions) + with set_ascend_forward_context(attn_metadata, self.vllm_config, num_tokens=input_batch_size): last_hidden_states, hidden_states = self.model( input_ids=self.input_ids[:input_batch_size], - positions=self.positions[:input_batch_size], + positions=positions, hidden_states=self.hidden_states[:input_batch_size], ) hidden_states = hidden_states[:batch_size] diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index e48252f29ed..7ff918be356 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -111,6 +111,7 @@ from vllm_ascend.eplb.core.eplb_worker import EplbProcess from vllm_ascend.eplb.eplb_updator import EplbUpdator from vllm_ascend.eplb.utils import model_register +from vllm_ascend.ops.rotary_embedding import initialize_cos_sin, update_cos_sin from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort from vllm_ascend.sample.logits_processor import build_logitsprocs @@ -1145,6 +1146,9 @@ def _prepare_inputs( for layer_name in attn_group.layer_names: attn_metadata[layer_name] = attn_metadata_i + # update global cos, sin + update_cos_sin(positions) + if lmhead_tp_enable(): max_num_reqs_across_dp = self.max_num_reqs * self.uniform_decode_query_len logits_indices = nn.functional.pad( From e0c5139ff9c202ad64321d2c209ba9853678a5dc Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Thu, 11 Dec 2025 02:59:04 +0000 Subject: [PATCH 18/40] tiny fix Signed-off-by: wxsIcey <1790571317@qq.com> --- vllm_ascend/ascend_forward_context.py | 6 ------ vllm_ascend/worker/model_runner_v1.py | 26 ++------------------------ 2 files changed, 2 insertions(+), 30 deletions(-) diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index 8faa01fa9c5..b4343e76f0c 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -44,8 +44,6 @@ def set_ascend_forward_context( prefetch_stream: torch.npu.Stream = None, model_instance: torch.nn.Module = None, weight_prefetch_method: Optional[WeightPrefetchMethod] = None, - cos: Optional[torch.Tensor] = None, - sin: Optional[torch.Tensor] = None, is_mtp_model=False): """A context manager that stores the current forward context, can be attention metadata, etc. @@ -130,10 +128,6 @@ def set_ascend_forward_context( forward_context.weight_prefetch_method = weight_prefetch_method forward_context.is_mtp_model = is_mtp_model - # initialize rope - forward_context.cos = cos - forward_context.sin = sin - if num_tokens is None and attn_metadata is not None: num_tokens = attn_metadata.num_actual_tokens diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 7ff918be356..75290dd81be 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1448,11 +1448,7 @@ def execute_model( total_num_scheduled_tokens, prefetch_stream=self.prefetch_stream, model_instance=self.model, - weight_prefetch_method=self.weight_prefetch_method, - cos=self.cos[:, :maybe_padded_num_tokens] - if self.cos is not None else None, - sin=self.sin[:, :maybe_padded_num_tokens] - if self.sin is not None else None): + weight_prefetch_method=self.weight_prefetch_method): self.maybe_setup_kv_connector(scheduler_output) hidden_states = self._generate_process_reqs_hidden_states( @@ -2155,20 +2151,6 @@ def dummy_drafter_compute_logits(hidden_states): return self.drafter.model.compute_logits( hidden_states[dummy_indices]) - # initialize rope - cos_sin_cache = self.model.model.layers[ - self.model.model. - start_layer].self_attn.rotary_emb.cos_sin_cache.index_select( - 0, positions) - last_dim = cos_sin_cache.size()[-1] - cos, sin = cos_sin_cache.reshape(-1, 2, last_dim // 2).repeat( - 1, 1, 2).chunk(2, dim=-2) - # BSNH - self.cos[:, :num_tokens] = cos.view(1, -1, 1, - last_dim).contiguous() - self.sin[:, :num_tokens] = sin.view(1, -1, 1, - last_dim).contiguous() - with set_ascend_forward_context( attn_metadata, self.vllm_config, @@ -2181,11 +2163,7 @@ def dummy_drafter_compute_logits(hidden_states): batch_descriptor=batch_descriptor, prefetch_stream=self.prefetch_stream, model_instance=self.model, - weight_prefetch_method=self.weight_prefetch_method, - cos=self.cos[:, :num_tokens] - if self.cos is not None else None, - sin=self.sin[:, :num_tokens] - if self.sin is not None else None): + weight_prefetch_method=self.weight_prefetch_method): hidden_states = self._generate_dummy_run_hidden_states( input_ids, positions, num_tokens_padded, intermediate_tensors, inputs_embeds) From fdb549cc2e335028e2efa6cb0d1386a8b9a50d49 Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Thu, 11 Dec 2025 04:49:02 +0000 Subject: [PATCH 19/40] fix triton Signed-off-by: wxsIcey <1790571317@qq.com> --- vllm_ascend/ops/__init__.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm_ascend/ops/__init__.py b/vllm_ascend/ops/__init__.py index 558ff6b8102..aadb6164705 100644 --- a/vllm_ascend/ops/__init__.py +++ b/vllm_ascend/ops/__init__.py @@ -16,11 +16,15 @@ # import torch +from vllm.triton_utils import HAS_TRITON import vllm_ascend.ops.fused_moe.fused_moe # noqa import vllm_ascend.ops.layernorm # noqa import vllm_ascend.ops.register_custom_ops # noqa -import vllm_ascend.ops.triton.linearnorm.split_qkv_rmsnorm_rope # noqa + +if HAS_TRITON: + import vllm_ascend.ops.triton.linearnorm.split_qkv_rmsnorm_rope # noqa + import vllm_ascend.ops.vocab_parallel_embedding # noqa from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul from vllm_ascend.ops.rotary_embedding import ( From ead3622b98916f3cdb1eed9453ab5cefa438c99e Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Thu, 11 Dec 2025 06:13:57 +0000 Subject: [PATCH 20/40] fix Signed-off-by: wxsIcey <1790571317@qq.com> --- vllm_ascend/ops/rotary_embedding.py | 1 - vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index 3a81647dc32..ac228cccd1c 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -20,7 +20,6 @@ import torch import torch_npu -from vllm.forward_context import get_forward_context from vllm.model_executor.layers.rotary_embedding import ( DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding, YaRNScalingRotaryEmbedding) diff --git a/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py b/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py index fba0f31203d..58e56ae5ee5 100644 --- a/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py +++ b/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py @@ -17,9 +17,8 @@ from typing import Optional import torch -import triton +import triton # type: ignore import triton.language as tl # type: ignore -import triton.runtime.driver as driver from vllm.utils.torch_utils import direct_register_custom_op from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num From 4a02ac5d2882d3427183b7d753eebe6c1311a85f Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Thu, 11 Dec 2025 08:28:17 +0000 Subject: [PATCH 21/40] install triton Signed-off-by: wxsIcey <1790571317@qq.com> --- .github/workflows/_e2e_test.yaml | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/.github/workflows/_e2e_test.yaml b/.github/workflows/_e2e_test.yaml index f4c6a65e47b..1168218e014 100644 --- a/.github/workflows/_e2e_test.yaml +++ b/.github/workflows/_e2e_test.yaml @@ -68,6 +68,12 @@ jobs: pip install -r requirements-dev.txt pip install -v -e . + - name: Install Ascend toolkit & triton_ascend + shell: bash -l {0} + run: | + . /usr/local/Ascend/ascend-toolkit/8.3.RC2/bisheng_toolkit/set_env.sh + python3 -m pip install "https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/triton_ascend-3.2.0.dev2025110717-cp311-cp311-manylinux_2_27_aarch64.whl" + - name: Run vllm-project/vllm-ascend test env: VLLM_WORKER_MULTIPROC_METHOD: spawn @@ -164,6 +170,12 @@ jobs: pip install -r requirements-dev.txt pip install -v -e . + - name: Install Ascend toolkit & triton_ascend + shell: bash -l {0} + run: | + . /usr/local/Ascend/ascend-toolkit/8.3.RC2/bisheng_toolkit/set_env.sh + python3 -m pip install "https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/triton_ascend-3.2.0.dev2025110717-cp311-cp311-manylinux_2_27_aarch64.whl" + - name: Run vllm-project/vllm-ascend test (light) env: VLLM_WORKER_MULTIPROC_METHOD: spawn @@ -258,6 +270,12 @@ jobs: pip install -r requirements-dev.txt pip install -v -e . + - name: Install Ascend toolkit & triton_ascend + shell: bash -l {0} + run: | + . /usr/local/Ascend/ascend-toolkit/8.3.RC2/bisheng_toolkit/set_env.sh + python3 -m pip install "https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/triton_ascend-3.2.0.dev2025110717-cp311-cp311-manylinux_2_27_aarch64.whl" + - name: Run vllm-project/vllm-ascend test for V1 Engine working-directory: ./vllm-ascend env: @@ -269,12 +287,6 @@ jobs: pytest -sv --durations=0 tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Kimi_K2_Thinking_W4A16 pytest -sv --durations=0 tests/e2e/multicard/test_data_parallel_tp2.py - - name: Install Ascend toolkit & triton_ascend (for Qwen3-Next-80B-A3B-Instruct) - shell: bash -l {0} - run: | - . /usr/local/Ascend/ascend-toolkit/8.3.RC2/bisheng_toolkit/set_env.sh - python3 -m pip install "https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/triton_ascend-3.2.0.dev2025110717-cp311-cp311-manylinux_2_27_aarch64.whl" - - name: Run vllm-project/vllm-ascend Qwen3 Next test working-directory: ./vllm-ascend shell: bash -el {0} From 9019f1646caed17541bc57da0750149aa141245d Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Fri, 12 Dec 2025 07:18:30 +0000 Subject: [PATCH 22/40] fix ut Signed-off-by: wxsIcey <1790571317@qq.com> --- .github/workflows/_e2e_test.yaml | 9 +++- .../e2e/singlecard/test_aclgraph_accuracy.py | 42 +++++++++---------- vllm_ascend/ascend_config.py | 9 ++-- 3 files changed, 32 insertions(+), 28 deletions(-) diff --git a/.github/workflows/_e2e_test.yaml b/.github/workflows/_e2e_test.yaml index 1168218e014..33c0a0e5c40 100644 --- a/.github/workflows/_e2e_test.yaml +++ b/.github/workflows/_e2e_test.yaml @@ -170,6 +170,14 @@ jobs: pip install -r requirements-dev.txt pip install -v -e . + - name: Run vllm-project/vllm-ascend test (no triton) + env: + VLLM_WORKER_MULTIPROC_METHOD: spawn + VLLM_USE_MODELSCOPE: True + if: ${{ inputs.type == 'full' }} + run: | + pytest -sv --durations=0 tests/e2e/multicard/test_aclgraph_capture_replay.py + - name: Install Ascend toolkit & triton_ascend shell: bash -l {0} run: | @@ -191,7 +199,6 @@ jobs: if: ${{ inputs.type == 'full' }} run: | pytest -sv --durations=0 tests/e2e/multicard/test_quantization.py - pytest -sv --durations=0 tests/e2e/multicard/test_aclgraph_capture_replay.py pytest -sv --durations=0 tests/e2e/multicard/test_full_graph_mode.py pytest -sv --durations=0 tests/e2e/multicard/test_data_parallel.py pytest -sv --durations=0 tests/e2e/multicard/test_expert_parallel.py diff --git a/tests/e2e/singlecard/test_aclgraph_accuracy.py b/tests/e2e/singlecard/test_aclgraph_accuracy.py index b1878862b9f..cc57a1c96ae 100644 --- a/tests/e2e/singlecard/test_aclgraph_accuracy.py +++ b/tests/e2e/singlecard/test_aclgraph_accuracy.py @@ -22,6 +22,7 @@ import os +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" import pytest from vllm import SamplingParams @@ -36,7 +37,7 @@ @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("max_tokens", [32]) -def test_output_between_eager_and_aclgraph( +def test_output_with_aclgraph( model: str, max_tokens: int, ) -> None: @@ -44,6 +45,19 @@ def test_output_between_eager_and_aclgraph( "Hello, my name is", "The president of the United States is", "The capital of France is", "The future of AI is" ] + vllm_aclgraph_qwen_answers = [ + " Lina. I'm a 22-year-old student from China. I'm interested in studying in the US. I want to know if there are any", + ' the same as the president of the United Nations. This is because the president of the United States is the same as the president of the United Nations. The president', + ' Paris. The capital of France is also the capital of the Republic of France. The capital of France is also the capital of the European Union. The capital of', + ' not just a technological frontier but a profound transformation of how we live, work, and interact with the world. As we stand at the intersection of artificial intelligence and' + ] + + vllm_aclgraph_ds_answers = [ + '\nI am a 20 year old student from the UK. I am currently studying for a degree in English Literature and Creative Writing. I have a passion', + ' a man who has been in the public eye for decades. He has been a senator, a governor, and a businessman. He has also been married to the', + ' Paris, which is also the largest city in the country. The city is located on the River Seine and is known for its beautiful architecture, museums, and art', + ' here.\nThe future of AI is here.\nThe future of AI is here.\nThe future of AI is here.\nThe future of AI is' + ] sampling_params = SamplingParams(max_tokens=max_tokens, temperature=0.0) if model == "vllm-ascend/DeepSeek-V2-Lite-W8A8": @@ -55,15 +69,6 @@ def test_output_between_eager_and_aclgraph( ) as runner: vllm_aclgraph_outputs = runner.model.generate( prompts, sampling_params) - - with VllmRunner( - model, - max_model_len=1024, - enforce_eager=True, - quantization="ascend", - ) as runner: - vllm_eager_outputs = runner.model.generate(prompts, - sampling_params) else: with VllmRunner( model, @@ -72,23 +77,16 @@ def test_output_between_eager_and_aclgraph( ) as runner: vllm_aclgraph_outputs = runner.model.generate( prompts, sampling_params) - - with VllmRunner( - model, - max_model_len=1024, - enforce_eager=True, - ) as runner: - vllm_eager_outputs = runner.model.generate(prompts, - sampling_params) vllm_aclgraph_outputs_list = [] for output in vllm_aclgraph_outputs: vllm_aclgraph_outputs_list.append( (output.outputs[0].index, output.outputs[0].text)) - vllm_eager_outputs_list = [] - for output in vllm_eager_outputs: - vllm_eager_outputs_list.append( - (output.outputs[0].index, output.outputs[0].text)) + vllm_eager_outputs_list = ([ + ([0], answer) for answer in vllm_aclgraph_ds_answers + ] if model == "vllm-ascend/DeepSeek-V2-Lite-W8A8" else [ + ([0], answer) for answer in vllm_aclgraph_qwen_answers + ]) check_outputs_equal( outputs_0_lst=vllm_eager_outputs_list, diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 01f988859be..cd0146799e3 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -231,11 +231,10 @@ class AscendCompilationConfig: deployed on Ascend platforms. """ - def __init__(self, - fuse_norm_quant: bool = True, - fuse_qknorm_rope: bool = True, - **kwargs - ): + def __init__(self, + fuse_norm_quant: bool = True, + fuse_qknorm_rope: bool = True, + **kwargs): """ Initialize the configuration. From 09019ab7d8a6f9331e6353d7e2c2d2ae55f05964 Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Fri, 12 Dec 2025 07:46:46 +0000 Subject: [PATCH 23/40] fix Signed-off-by: wxsIcey <1790571317@qq.com> --- tests/e2e/singlecard/test_aclgraph_accuracy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/e2e/singlecard/test_aclgraph_accuracy.py b/tests/e2e/singlecard/test_aclgraph_accuracy.py index cc57a1c96ae..7df98db30fe 100644 --- a/tests/e2e/singlecard/test_aclgraph_accuracy.py +++ b/tests/e2e/singlecard/test_aclgraph_accuracy.py @@ -80,7 +80,7 @@ def test_output_with_aclgraph( vllm_aclgraph_outputs_list = [] for output in vllm_aclgraph_outputs: vllm_aclgraph_outputs_list.append( - (output.outputs[0].index, output.outputs[0].text)) + ([output.outputs[0].index], output.outputs[0].text)) vllm_eager_outputs_list = ([ ([0], answer) for answer in vllm_aclgraph_ds_answers From 5454d63e074d0fea3f6ac7d715dce2e8f0566ade Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Fri, 12 Dec 2025 08:54:59 +0000 Subject: [PATCH 24/40] fix ut Signed-off-by: wxsIcey <1790571317@qq.com> --- tests/e2e/multicard/test_aclgraph_capture_replay.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/e2e/multicard/test_aclgraph_capture_replay.py b/tests/e2e/multicard/test_aclgraph_capture_replay.py index e81b5615432..86a7e76c75a 100644 --- a/tests/e2e/multicard/test_aclgraph_capture_replay.py +++ b/tests/e2e/multicard/test_aclgraph_capture_replay.py @@ -28,7 +28,7 @@ from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type MODELS = [ - "Qwen/Qwen3-0.6B", + "facebook/opt-125m", "vllm-ascend/DeepSeek-V2-Lite-W8A8", ] From 360fcf59a26f039e2a09b2589f28ac1ec2eaeecb Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Sun, 14 Dec 2025 12:31:54 +0000 Subject: [PATCH 25/40] fix ut Signed-off-by: wxsIcey <1790571317@qq.com> --- tests/e2e/singlecard/test_aclgraph_accuracy.py | 2 +- vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/e2e/singlecard/test_aclgraph_accuracy.py b/tests/e2e/singlecard/test_aclgraph_accuracy.py index 7df98db30fe..0892b7b3a4a 100644 --- a/tests/e2e/singlecard/test_aclgraph_accuracy.py +++ b/tests/e2e/singlecard/test_aclgraph_accuracy.py @@ -134,7 +134,7 @@ def test_output_between_eager_and_full_decode_only( ] vllm_aclgraph_qwen_answers = [ ' \n\nTo solve this problem, we need to use the Law of Sines and Law of Cosines. Let me start by drawing triangle $ABC$ with the', - " \n\nTo solve this problem, we can use the fact that the expected value of the area of a triangle formed by two random points on a square's perimeter is", + ' \n\nTo solve this problem, we can use the following approach: Let $ABCD$ be a unit square with coordinates $A(0,0), B', ' \n\nTo solve this problem, we can use the following approach: Let $ \\alpha $ be the common real root of the two equations. Then, we can' ] diff --git a/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py b/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py index ea900c646fe..7907a6bc0f2 100644 --- a/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py +++ b/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py @@ -243,6 +243,11 @@ def __init__(self, vllm_config: VllmConfig): return layer = next(iter(attn_layers.values())) for epsilon in [1e-6, 1e-5]: + if layer.head_size != 128: + logging.debug( + "QKNorm and Rope fusion not enabled: head_dim %d is not equal of 128", + layer.head_size) + continue QKNormRopeFusionPattern(head_dim=layer.head_size, num_heads=layer.num_heads, num_kv_heads=layer.num_kv_heads, From 129fcde6d7d59d784d7865b1c67e46dab89b638d Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Sun, 14 Dec 2025 13:28:53 +0000 Subject: [PATCH 26/40] resolve conflict Signed-off-by: wxsIcey <1790571317@qq.com> --- vllm_ascend/worker/model_runner_v1.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 75290dd81be..03d0de3c329 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -123,7 +123,7 @@ from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, AscendDeviceType, ProfileExecuteDuration, enable_sp, get_ascend_device_type, is_enable_nz, - is_moe_model, lmhead_tp_enable, vllm_version_is) + is_moe_model, lmhead_tp_enable, vllm_version_is, is_vl_model) from vllm_ascend.worker.npu_input_batch import NPUInputBatch if TYPE_CHECKING: @@ -282,6 +282,9 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): set_cos_and_sin(vllm_config, self.max_num_reqs, self.uniform_decode_query_len, self.dtype, self.device) + if not is_vl_model(self.vllm_config + ) and not self.vllm_config.model_config.use_mla: + initialize_cos_sin(self.vllm_config, self.dtype, self.device) set_mc2_tokens_capacity(vllm_config, self.max_num_reqs, self.uniform_decode_query_len) set_mc2_mask(vllm_config, self.device) @@ -2111,6 +2114,9 @@ def _dummy_run( else: positions = self.positions.gpu[:num_tokens_padded] + # update global cos, sin + update_cos_sin(positions) + if get_pp_group().is_first_rank: intermediate_tensors = None else: From 536ead415146a05f5364d4cbff7a9d3e13a53275 Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Sun, 14 Dec 2025 13:47:11 +0000 Subject: [PATCH 27/40] change workflow Signed-off-by: wxsIcey <1790571317@qq.com> --- .github/workflows/_e2e_test.yaml | 3 ++- tests/e2e/multicard/test_aclgraph_capture_replay.py | 6 ++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/.github/workflows/_e2e_test.yaml b/.github/workflows/_e2e_test.yaml index 33c0a0e5c40..a6dd24dffe4 100644 --- a/.github/workflows/_e2e_test.yaml +++ b/.github/workflows/_e2e_test.yaml @@ -170,7 +170,8 @@ jobs: pip install -r requirements-dev.txt pip install -v -e . - - name: Run vllm-project/vllm-ascend test (no triton) + # This test doesn't require triton-ascend, we run it to avoid potential triton issues. + - name: Run vllm-project/vllm-ascend test (no triton-ascend) env: VLLM_WORKER_MULTIPROC_METHOD: spawn VLLM_USE_MODELSCOPE: True diff --git a/tests/e2e/multicard/test_aclgraph_capture_replay.py b/tests/e2e/multicard/test_aclgraph_capture_replay.py index 86a7e76c75a..d36c97a3455 100644 --- a/tests/e2e/multicard/test_aclgraph_capture_replay.py +++ b/tests/e2e/multicard/test_aclgraph_capture_replay.py @@ -27,10 +27,8 @@ from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type -MODELS = [ - "facebook/opt-125m", - "vllm-ascend/DeepSeek-V2-Lite-W8A8", -] +# here we delete qwen3-0.6b, please add it when the test can be enabled when trion-ascend is supported. +MODELS = ["vllm-ascend/DeepSeek-V2-Lite-W8A8"] def _install_spies(counters: dict[str, Any]) -> contextlib.ExitStack: From 859923756067946ad6b2ef3660b10c266d128ae3 Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Mon, 15 Dec 2025 02:04:17 +0000 Subject: [PATCH 28/40] tiny fix Signed-off-by: wxsIcey <1790571317@qq.com> --- vllm_ascend/ascend_config.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index cd0146799e3..29dcef8d54b 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -242,9 +242,6 @@ def __init__(self, fuse_norm_quant (bool): Whether to enable norm and quant fusion optimization. When set to True, the system will optimize norm and quant operations. Default: True - fuse_qknorm_rope (bool): Whether to enable qknorm and rope fusion optimization. - Default: True - fuse_qknorm_rope (bool): Whether to enable qknorm and rope fusion optimization. Default: True **kwargs: Additional optional parameters for forward compatibility and configuration extension. From 5de88fc1bf4d4a9d8811ec4ee1073d06df2fa280 Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Mon, 15 Dec 2025 02:28:23 +0000 Subject: [PATCH 29/40] fix Signed-off-by: wxsIcey <1790571317@qq.com> --- vllm_ascend/attention/mla_v1.py | 4 +- vllm_ascend/attention/sfa_v1.py | 4 +- vllm_ascend/ops/rotary_embedding.py | 113 +++++++++++++++++--------- vllm_ascend/worker/model_runner_v1.py | 8 +- 4 files changed, 82 insertions(+), 47 deletions(-) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 29643c0505b..e250fdbc0ad 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -22,7 +22,6 @@ from vllm_ascend import envs from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.ascend_forward_context import get_cos_and_sin from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, maybe_save_kv_layer_to_connector, @@ -32,6 +31,7 @@ from vllm_ascend.compilation.acl_graph import (get_graph_params, get_mtp_graph_params, update_graph_params_workspaces) +from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla from vllm_ascend.ops.shared_weight_layer import ( is_hidden_layer, post_process_after_loading_for_shared_weight_series, reach_layer_for_shared_weight_series, @@ -531,7 +531,7 @@ def build( decode_metadata = None if num_decodes > 0: - cos, sin = get_cos_and_sin() + cos, sin = get_cos_and_sin_mla() # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario actual_seq_lengths_q = query_start_loc_cpu[1:num_decodes + 1].tolist() diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py index f6b338a807c..a53e9423412 100644 --- a/vllm_ascend/attention/sfa_v1.py +++ b/vllm_ascend/attention/sfa_v1.py @@ -16,12 +16,12 @@ from vllm_ascend import envs from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.ascend_forward_context import get_cos_and_sin from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, trans_rope_weight, transdata, wait_for_kv_layer_from_connector) +from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla from vllm_ascend.ops.shared_weight_layer import ( is_hidden_layer, post_process_after_loading_for_shared_weight_series, reach_layer_for_shared_weight_series, @@ -187,7 +187,7 @@ def build( cum_query_lens = common_attn_metadata.query_start_loc[1:num_reqs + 1] seq_lens = common_attn_metadata.seq_lens[:num_reqs] - cos, sin = get_cos_and_sin() + cos, sin = get_cos_and_sin_mla() assert self.cos_cache is not None and self.sin_cache is not None new_cos = self.cos_cache[input_positions][:, None, None] diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index ac228cccd1c..6dd1a984d87 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -20,23 +20,82 @@ import torch import torch_npu +from vllm.config import CUDAGraphMode from vllm.model_executor.layers.rotary_embedding import ( DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding, YaRNScalingRotaryEmbedding) from vllm_ascend.platform import NPUPlatform from vllm_ascend.utils import (AscendDeviceType, enable_custom_op, - get_ascend_device_type) + get_ascend_device_type, is_vl_model) # Currently, rope ops used on npu requires detached cos && sin as inputs. # However, RotaryEmbedding in vllm use cos_sin_cache as a whole variable. # So we have to preprocess cos_sin_cache int cos && sin. In the future, # we shall implement a new rope ops which accept cos_sin_cache as inputs. +# NOTE(Angazenn): MLA && SFA models uses attn_metadata to pass cos && sin +# to rope in AscendMLA(SFA)Impl. However, since rope is isolated from +# AscendAttentionBackendImpl for GQA models, we cannot pass cos && sin by +# attn_metadata. This causes that rope in GQA models must pass cos && sin +# by different approaches. +_cos_mla: Optional[torch.Tensor] = None +_sin_mla: Optional[torch.Tensor] = None _cos_sin_cache: Optional[torch.Tensor] = None -_cos_cache: Optional[torch.Tensor] = None -_sin_cache: Optional[torch.Tensor] = None _cos: Optional[torch.Tensor] = None _sin: Optional[torch.Tensor] = None +_cos_slice: Optional[torch.Tensor] = None +_sin_slice: Optional[torch.Tensor] = None + + +def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype, + device): + global _cos_mla + global _sin_mla + global _cos + global _sin + + if _cos_mla is not None or \ + _sin_mla is not None or \ + _cos is not None or \ + _sin is not None: + return + + compilation_config = vllm_config.compilation_config + model_config = vllm_config.model_config + head_dim = model_config.get_head_size() + max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens + + if model_config.use_mla and compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY: + rope_dim = model_config.hf_text_config.qk_rope_head_dim + _cos_mla = torch.ones(max_num_reqs * decode_token_per_req, + 1, + 1, + rope_dim, + dtype=dtype, + device=device) + _sin_mla = torch.zeros(max_num_reqs * decode_token_per_req, + 1, + 1, + rope_dim, + dtype=dtype, + device=device) + elif not is_vl_model(vllm_config) and not vllm_config.model_config.use_mla: + _cos = torch.ones(1, + max_num_batched_tokens, + 1, + head_dim, + dtype=dtype, + device=device) + _sin = torch.zeros(1, + max_num_batched_tokens, + 1, + head_dim, + dtype=dtype, + device=device) + + +def get_cos_and_sin_mla(): + return _cos_mla, _sin_mla def _record_cos_sin_cache(cos_sin_cache): @@ -46,50 +105,28 @@ def _record_cos_sin_cache(cos_sin_cache): _cos_sin_cache = cos_sin_cache -def initialize_cos_sin(vllm_config, dtype, device): - global _cos_cache - global _sin_cache - - head_dim = vllm_config.model_config.get_head_size() - max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens - _cos_cache = torch.ones(1, - max_num_batched_tokens, - 1, - head_dim, - dtype=dtype, - device=device) - _sin_cache = torch.zeros(1, - max_num_batched_tokens, - 1, - head_dim, - dtype=dtype, - device=device) - - def update_cos_sin(positions): - global _cos_cache - global _sin_cache global _cos global _sin + global _cos_slice + global _sin_slice if _cos_sin_cache is None or \ - _cos_cache is None or \ - _sin_cache is None: + _cos is None or \ + _sin is None: return num_tokens = positions.size(0) - _cos_cache[:, :num_tokens] = _cos_sin_cache.index_select( - 0, positions).view(num_tokens, 2, -1).repeat(1, 1, 2).chunk(2, - dim=-2)[0] - _sin_cache[:, :num_tokens] = _cos_sin_cache.index_select( - 0, positions).view(num_tokens, 2, -1).repeat(1, 1, 2).chunk(2, - dim=-2)[1] - _cos = _cos_cache[:, :num_tokens] - _sin = _sin_cache[:, :num_tokens] + _cos[:, :num_tokens] = _cos_sin_cache.index_select(0, positions).view( + num_tokens, 2, -1).repeat(1, 1, 2).chunk(2, dim=-2)[0] + _sin[:, :num_tokens] = _cos_sin_cache.index_select(0, positions).view( + num_tokens, 2, -1).repeat(1, 1, 2).chunk(2, dim=-2)[1] + _cos_slice = _cos[:, :num_tokens] + _sin_slice = _sin[:, :num_tokens] -def get_cos_sin(): - return _cos, _sin +def get_cos_and_sin_slice(): + return _cos_slice, _sin_slice def _custom_rotary_embedding_enabled(query, neox_style, head_size): @@ -127,7 +164,7 @@ def _rope_forward_oot( raise NotImplementedError( "Batched rotary embedding is currently not supported on NPU.") else: - cos, sin = get_cos_sin() + cos, sin = get_cos_and_sin_slice() if is_neox_style and self.head_size == 128 and self.cos_sin_cache.shape[ -1] == 128 and cos is not None and sin is not None: # If cos and sin are generated outside, use npu_apply_rotary_pos_emb to avoid redundant calculation. diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 03d0de3c329..46d12a4e2a7 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -88,7 +88,7 @@ get_mc2_tokens_capacity, select_moe_comm_method, set_ascend_forward_context, - set_cos_and_sin, set_mc2_mask, + set_mc2_mask, set_mc2_tokens_capacity) from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_v1 import AscendAttentionState @@ -111,7 +111,8 @@ from vllm_ascend.eplb.core.eplb_worker import EplbProcess from vllm_ascend.eplb.eplb_updator import EplbUpdator from vllm_ascend.eplb.utils import model_register -from vllm_ascend.ops.rotary_embedding import initialize_cos_sin, update_cos_sin +from vllm_ascend.ops.rotary_embedding import (initialize_cos_sin, + set_cos_and_sin, update_cos_sin) from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort from vllm_ascend.sample.logits_processor import build_logitsprocs @@ -282,9 +283,6 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): set_cos_and_sin(vllm_config, self.max_num_reqs, self.uniform_decode_query_len, self.dtype, self.device) - if not is_vl_model(self.vllm_config - ) and not self.vllm_config.model_config.use_mla: - initialize_cos_sin(self.vllm_config, self.dtype, self.device) set_mc2_tokens_capacity(vllm_config, self.max_num_reqs, self.uniform_decode_query_len) set_mc2_mask(vllm_config, self.device) From fe8a177ae285d5a4da34d914e9260f43e0645f2b Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Mon, 15 Dec 2025 02:39:30 +0000 Subject: [PATCH 30/40] fix Signed-off-by: wxsIcey <1790571317@qq.com> --- vllm_ascend/worker/model_runner_v1.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 46d12a4e2a7..d22d78a0cb5 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -111,8 +111,7 @@ from vllm_ascend.eplb.core.eplb_worker import EplbProcess from vllm_ascend.eplb.eplb_updator import EplbUpdator from vllm_ascend.eplb.utils import model_register -from vllm_ascend.ops.rotary_embedding import (initialize_cos_sin, - set_cos_and_sin, update_cos_sin) +from vllm_ascend.ops.rotary_embedding import set_cos_and_sin, update_cos_sin from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort from vllm_ascend.sample.logits_processor import build_logitsprocs From 2c4371227c7ec61b308536906362e8e96b884aa4 Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Mon, 15 Dec 2025 06:55:35 +0000 Subject: [PATCH 31/40] fix qwen3 next Signed-off-by: wxsIcey <1790571317@qq.com> --- vllm_ascend/ops/rotary_embedding.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index 6dd1a984d87..dcf54da02da 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -80,16 +80,21 @@ def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype, dtype=dtype, device=device) elif not is_vl_model(vllm_config) and not vllm_config.model_config.use_mla: + rope_dim = model_config.get_head_size() + # For models using partial rope like Qwen3-Next. + if hasattr(model_config.hf_text_config, "partial_rotary_factor"): + rope_dim = int(rope_dim * + model_config.hf_text_config.partial_rotary_factor) _cos = torch.ones(1, max_num_batched_tokens, 1, - head_dim, + rope_dim, dtype=dtype, device=device) _sin = torch.zeros(1, max_num_batched_tokens, 1, - head_dim, + rope_dim, dtype=dtype, device=device) From 519aa974d5d596752c34042c1d9a6f51f5ffed8c Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Mon, 15 Dec 2025 08:44:09 +0000 Subject: [PATCH 32/40] fix eagle Signed-off-by: wxsIcey <1790571317@qq.com> --- vllm_ascend/ops/rotary_embedding.py | 1 - vllm_ascend/spec_decode/eagle_proposer.py | 8 +++----- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index dcf54da02da..7f86047028f 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -62,7 +62,6 @@ def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype, compilation_config = vllm_config.compilation_config model_config = vllm_config.model_config - head_dim = model_config.get_head_size() max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens if model_config.use_mla and compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY: diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index 2430c52f273..573af0e64f6 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -150,7 +150,6 @@ def dummy_run(self, num_tokens=num_tokens): self.model( input_ids=self.input_ids[:num_tokens], - positions=positions, hidden_states=self.hidden_states[:num_tokens], ) dummy_compute_logits(self.hidden_states) @@ -350,7 +349,7 @@ def _propose( num_tokens=num_input_tokens): last_hidden_states, hidden_states = self.model( input_ids=self.input_ids[:num_input_tokens], - positions=positions, + positions=self.positions[:num_input_tokens], hidden_states=self.hidden_states[:num_input_tokens], ) sample_hidden_states = last_hidden_states[last_token_indices] @@ -450,9 +449,8 @@ def _propose( attn_metadata.attn_mask = attn_mask # Run the model. - positions = self.positions[:input_batch_size] # update global cos, sin - update_cos_sin(positions) + update_cos_sin(self.positions[:input_batch_size]) with set_ascend_forward_context(attn_metadata, self.vllm_config, @@ -460,7 +458,7 @@ def _propose( last_hidden_states, hidden_states = self.model( input_ids=self.input_ids[:input_batch_size], - positions=positions, + positions=self.positions[:input_batch_size], hidden_states=self.hidden_states[:input_batch_size], ) hidden_states = hidden_states[:batch_size] From 5e22e20d4879d62de9b5dd0b002fd8a3089e3920 Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Mon, 15 Dec 2025 08:46:08 +0000 Subject: [PATCH 33/40] fix eagle Signed-off-by: wxsIcey <1790571317@qq.com> --- vllm_ascend/spec_decode/eagle_proposer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index 573af0e64f6..74e846904c8 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -150,6 +150,7 @@ def dummy_run(self, num_tokens=num_tokens): self.model( input_ids=self.input_ids[:num_tokens], + positions=self.positions[:num_tokens], hidden_states=self.hidden_states[:num_tokens], ) dummy_compute_logits(self.hidden_states) From 5e6a838e2129357d560964b892c5d6347423ff8d Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Mon, 15 Dec 2025 09:44:10 +0000 Subject: [PATCH 34/40] fix mla cp Signed-off-by: wxsIcey <1790571317@qq.com> --- vllm_ascend/attention/mla_cp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm_ascend/attention/mla_cp.py b/vllm_ascend/attention/mla_cp.py index 6aa3848dc72..980f3c5050a 100644 --- a/vllm_ascend/attention/mla_cp.py +++ b/vllm_ascend/attention/mla_cp.py @@ -16,7 +16,6 @@ from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.kv_cache_interface import MLAAttentionSpec -from vllm_ascend.ascend_forward_context import get_cos_and_sin from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata, AscendMLAImpl, AscendMLAMetadata, AscendMLAMetadataBuilder, @@ -29,6 +28,7 @@ wait_for_kv_layer_from_connector) from vllm_ascend.compilation.acl_graph import (get_graph_params, update_graph_params_workspaces) +from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla from vllm_ascend.ops.shared_weight_layer import ( is_hidden_layer, reach_layer_for_shared_weight_series) from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch @@ -286,7 +286,7 @@ def build( decode_metadata = None if num_decodes > 0: - cos, sin = get_cos_and_sin() + cos, sin = get_cos_and_sin_mla() # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario actual_seq_lengths_q = query_start_loc_cpu[1:num_decodes + 1].tolist() From 993744a2fb8dfdf64332a4b268fc716d8c028577 Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Mon, 15 Dec 2025 13:50:53 +0000 Subject: [PATCH 35/40] fix Signed-off-by: wxsIcey <1790571317@qq.com> --- vllm_ascend/worker/model_runner_v1.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index d22d78a0cb5..3daf67dc76f 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -123,7 +123,7 @@ from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, AscendDeviceType, ProfileExecuteDuration, enable_sp, get_ascend_device_type, is_enable_nz, - is_moe_model, lmhead_tp_enable, vllm_version_is, is_vl_model) + is_moe_model, lmhead_tp_enable, vllm_version_is) from vllm_ascend.worker.npu_input_batch import NPUInputBatch if TYPE_CHECKING: From 31aa1b2930468a00783cb2d01a43185ef116165d Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Tue, 16 Dec 2025 02:11:06 +0000 Subject: [PATCH 36/40] fix e2e Signed-off-by: wxsIcey <1790571317@qq.com> --- .../e2e/singlecard/test_aclgraph_accuracy.py | 39 +++++++++++++------ 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/tests/e2e/singlecard/test_aclgraph_accuracy.py b/tests/e2e/singlecard/test_aclgraph_accuracy.py index 0892b7b3a4a..0238cde21df 100644 --- a/tests/e2e/singlecard/test_aclgraph_accuracy.py +++ b/tests/e2e/singlecard/test_aclgraph_accuracy.py @@ -28,6 +28,7 @@ from tests.e2e.conftest import VllmRunner from tests.e2e.model_utils import check_outputs_equal +from vllm_ascend.utils import vllm_version_is MODELS = [ "Qwen/Qwen3-0.6B", @@ -45,12 +46,20 @@ def test_output_with_aclgraph( "Hello, my name is", "The president of the United States is", "The capital of France is", "The future of AI is" ] - vllm_aclgraph_qwen_answers = [ - " Lina. I'm a 22-year-old student from China. I'm interested in studying in the US. I want to know if there are any", - ' the same as the president of the United Nations. This is because the president of the United States is the same as the president of the United Nations. The president', - ' Paris. The capital of France is also the capital of the Republic of France. The capital of France is also the capital of the European Union. The capital of', - ' not just a technological frontier but a profound transformation of how we live, work, and interact with the world. As we stand at the intersection of artificial intelligence and' - ] + if vllm_version_is("0.12.0"): + vllm_aclgraph_qwen_answers = [ + " Lina. I'm a 22-year-old student from China. I'm interested in studying in the US. I want to know if there are any", + ' the same as the president of the United Nations. This is because the president of the United States is the same as the president of the United Nations. The president', + ' Paris. The capital of France is also the capital of the Republic of France. The capital of France is also the capital of the European Union. The capital of', + ' not just a technological frontier but a profound transformation of how we live, work, and interact with the world. As we stand at the intersection of artificial intelligence and' + ] + else: + vllm_aclgraph_qwen_answers = [ + " Lina. I'm a 22-year-old student from China. I'm interested in studying in the US. I'm looking for a job in the", + ' the same as the president of the United Nations. This is because the president of the United States is the same as the president of the United Nations. The president', + ' Paris. The capital of Italy is Rome. The capital of Spain is Madrid. The capital of China is Beijing. The capital of Japan is Tokyo. The capital', + ' not just a technological challenge but a profound transformation of how we live, work, and interact with the world. As we stand at the intersection of artificial intelligence and' + ] vllm_aclgraph_ds_answers = [ '\nI am a 20 year old student from the UK. I am currently studying for a degree in English Literature and Creative Writing. I have a passion', @@ -132,12 +141,18 @@ def test_output_between_eager_and_full_decode_only( 'and $x^2 + cx + b = 0$ also have a common real root.' 'Compute the sum $a + b + c$.') ] - vllm_aclgraph_qwen_answers = [ - ' \n\nTo solve this problem, we need to use the Law of Sines and Law of Cosines. Let me start by drawing triangle $ABC$ with the', - ' \n\nTo solve this problem, we can use the following approach: Let $ABCD$ be a unit square with coordinates $A(0,0), B', - ' \n\nTo solve this problem, we can use the following approach: Let $ \\alpha $ be the common real root of the two equations. Then, we can' - ] - + if vllm_version_is("0.12.0"): + vllm_aclgraph_qwen_answers = [ + ' \n\nTo solve this problem, we need to use the Law of Sines and Law of Cosines. Let me start by drawing triangle $ABC$ with the', + ' \n\nTo solve this problem, we can use the following approach: Let $ABCD$ be a unit square with coordinates $A(0,0), B', + ' \n\nTo solve this problem, we can use the following approach: Let $ \\alpha $ be the common real root of the two equations. Then, we can' + ] + else: + vllm_aclgraph_qwen_answers = [ + ' \n\nTo solve this problem, we need to use the Law of Sines and Law of Cosines. Let me start by drawing triangle $ABC$ with the', + ' \n\nTo solve this problem, we can use the fact that the expected value of the area of a triangle formed by two random points on a square\'s perimeter is', + ' \n\nTo solve this problem, we can use the following approach: Let $ \\alpha $ be the common real root of the two equations. Then, we can' + ] vllm_aclgraph_ds_answers = [ '\n\nSelect an assignment template', '\n\nSelect an assignment template', From 01814e5e86bfb46aa3bc4143ec998eecc7f345a4 Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Tue, 16 Dec 2025 06:40:56 +0000 Subject: [PATCH 37/40] recover e2e Signed-off-by: wxsIcey <1790571317@qq.com> --- .github/workflows/_e2e_test.yaml | 34 ++------- .../multicard/test_aclgraph_capture_replay.py | 6 +- .../e2e/singlecard/test_aclgraph_accuracy.py | 73 ++++++++----------- vllm_ascend/ascend_config.py | 7 +- 4 files changed, 45 insertions(+), 75 deletions(-) diff --git a/.github/workflows/_e2e_test.yaml b/.github/workflows/_e2e_test.yaml index a6dd24dffe4..f4c6a65e47b 100644 --- a/.github/workflows/_e2e_test.yaml +++ b/.github/workflows/_e2e_test.yaml @@ -68,12 +68,6 @@ jobs: pip install -r requirements-dev.txt pip install -v -e . - - name: Install Ascend toolkit & triton_ascend - shell: bash -l {0} - run: | - . /usr/local/Ascend/ascend-toolkit/8.3.RC2/bisheng_toolkit/set_env.sh - python3 -m pip install "https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/triton_ascend-3.2.0.dev2025110717-cp311-cp311-manylinux_2_27_aarch64.whl" - - name: Run vllm-project/vllm-ascend test env: VLLM_WORKER_MULTIPROC_METHOD: spawn @@ -170,21 +164,6 @@ jobs: pip install -r requirements-dev.txt pip install -v -e . - # This test doesn't require triton-ascend, we run it to avoid potential triton issues. - - name: Run vllm-project/vllm-ascend test (no triton-ascend) - env: - VLLM_WORKER_MULTIPROC_METHOD: spawn - VLLM_USE_MODELSCOPE: True - if: ${{ inputs.type == 'full' }} - run: | - pytest -sv --durations=0 tests/e2e/multicard/test_aclgraph_capture_replay.py - - - name: Install Ascend toolkit & triton_ascend - shell: bash -l {0} - run: | - . /usr/local/Ascend/ascend-toolkit/8.3.RC2/bisheng_toolkit/set_env.sh - python3 -m pip install "https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/triton_ascend-3.2.0.dev2025110717-cp311-cp311-manylinux_2_27_aarch64.whl" - - name: Run vllm-project/vllm-ascend test (light) env: VLLM_WORKER_MULTIPROC_METHOD: spawn @@ -200,6 +179,7 @@ jobs: if: ${{ inputs.type == 'full' }} run: | pytest -sv --durations=0 tests/e2e/multicard/test_quantization.py + pytest -sv --durations=0 tests/e2e/multicard/test_aclgraph_capture_replay.py pytest -sv --durations=0 tests/e2e/multicard/test_full_graph_mode.py pytest -sv --durations=0 tests/e2e/multicard/test_data_parallel.py pytest -sv --durations=0 tests/e2e/multicard/test_expert_parallel.py @@ -278,12 +258,6 @@ jobs: pip install -r requirements-dev.txt pip install -v -e . - - name: Install Ascend toolkit & triton_ascend - shell: bash -l {0} - run: | - . /usr/local/Ascend/ascend-toolkit/8.3.RC2/bisheng_toolkit/set_env.sh - python3 -m pip install "https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/triton_ascend-3.2.0.dev2025110717-cp311-cp311-manylinux_2_27_aarch64.whl" - - name: Run vllm-project/vllm-ascend test for V1 Engine working-directory: ./vllm-ascend env: @@ -295,6 +269,12 @@ jobs: pytest -sv --durations=0 tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Kimi_K2_Thinking_W4A16 pytest -sv --durations=0 tests/e2e/multicard/test_data_parallel_tp2.py + - name: Install Ascend toolkit & triton_ascend (for Qwen3-Next-80B-A3B-Instruct) + shell: bash -l {0} + run: | + . /usr/local/Ascend/ascend-toolkit/8.3.RC2/bisheng_toolkit/set_env.sh + python3 -m pip install "https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/triton_ascend-3.2.0.dev2025110717-cp311-cp311-manylinux_2_27_aarch64.whl" + - name: Run vllm-project/vllm-ascend Qwen3 Next test working-directory: ./vllm-ascend shell: bash -el {0} diff --git a/tests/e2e/multicard/test_aclgraph_capture_replay.py b/tests/e2e/multicard/test_aclgraph_capture_replay.py index d36c97a3455..e81b5615432 100644 --- a/tests/e2e/multicard/test_aclgraph_capture_replay.py +++ b/tests/e2e/multicard/test_aclgraph_capture_replay.py @@ -27,8 +27,10 @@ from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type -# here we delete qwen3-0.6b, please add it when the test can be enabled when trion-ascend is supported. -MODELS = ["vllm-ascend/DeepSeek-V2-Lite-W8A8"] +MODELS = [ + "Qwen/Qwen3-0.6B", + "vllm-ascend/DeepSeek-V2-Lite-W8A8", +] def _install_spies(counters: dict[str, Any]) -> contextlib.ExitStack: diff --git a/tests/e2e/singlecard/test_aclgraph_accuracy.py b/tests/e2e/singlecard/test_aclgraph_accuracy.py index 0238cde21df..e217f0bbcaa 100644 --- a/tests/e2e/singlecard/test_aclgraph_accuracy.py +++ b/tests/e2e/singlecard/test_aclgraph_accuracy.py @@ -22,13 +22,11 @@ import os -os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" import pytest from vllm import SamplingParams from tests.e2e.conftest import VllmRunner from tests.e2e.model_utils import check_outputs_equal -from vllm_ascend.utils import vllm_version_is MODELS = [ "Qwen/Qwen3-0.6B", @@ -38,7 +36,7 @@ @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("max_tokens", [32]) -def test_output_with_aclgraph( +def test_output_between_eager_and_aclgraph( model: str, max_tokens: int, ) -> None: @@ -46,27 +44,6 @@ def test_output_with_aclgraph( "Hello, my name is", "The president of the United States is", "The capital of France is", "The future of AI is" ] - if vllm_version_is("0.12.0"): - vllm_aclgraph_qwen_answers = [ - " Lina. I'm a 22-year-old student from China. I'm interested in studying in the US. I want to know if there are any", - ' the same as the president of the United Nations. This is because the president of the United States is the same as the president of the United Nations. The president', - ' Paris. The capital of France is also the capital of the Republic of France. The capital of France is also the capital of the European Union. The capital of', - ' not just a technological frontier but a profound transformation of how we live, work, and interact with the world. As we stand at the intersection of artificial intelligence and' - ] - else: - vllm_aclgraph_qwen_answers = [ - " Lina. I'm a 22-year-old student from China. I'm interested in studying in the US. I'm looking for a job in the", - ' the same as the president of the United Nations. This is because the president of the United States is the same as the president of the United Nations. The president', - ' Paris. The capital of Italy is Rome. The capital of Spain is Madrid. The capital of China is Beijing. The capital of Japan is Tokyo. The capital', - ' not just a technological challenge but a profound transformation of how we live, work, and interact with the world. As we stand at the intersection of artificial intelligence and' - ] - - vllm_aclgraph_ds_answers = [ - '\nI am a 20 year old student from the UK. I am currently studying for a degree in English Literature and Creative Writing. I have a passion', - ' a man who has been in the public eye for decades. He has been a senator, a governor, and a businessman. He has also been married to the', - ' Paris, which is also the largest city in the country. The city is located on the River Seine and is known for its beautiful architecture, museums, and art', - ' here.\nThe future of AI is here.\nThe future of AI is here.\nThe future of AI is here.\nThe future of AI is' - ] sampling_params = SamplingParams(max_tokens=max_tokens, temperature=0.0) if model == "vllm-ascend/DeepSeek-V2-Lite-W8A8": @@ -78,6 +55,15 @@ def test_output_with_aclgraph( ) as runner: vllm_aclgraph_outputs = runner.model.generate( prompts, sampling_params) + + with VllmRunner( + model, + max_model_len=1024, + enforce_eager=True, + quantization="ascend", + ) as runner: + vllm_eager_outputs = runner.model.generate(prompts, + sampling_params) else: with VllmRunner( model, @@ -86,16 +72,23 @@ def test_output_with_aclgraph( ) as runner: vllm_aclgraph_outputs = runner.model.generate( prompts, sampling_params) + + with VllmRunner( + model, + max_model_len=1024, + enforce_eager=True, + ) as runner: + vllm_eager_outputs = runner.model.generate(prompts, + sampling_params) vllm_aclgraph_outputs_list = [] for output in vllm_aclgraph_outputs: vllm_aclgraph_outputs_list.append( - ([output.outputs[0].index], output.outputs[0].text)) + (output.outputs[0].index, output.outputs[0].text)) - vllm_eager_outputs_list = ([ - ([0], answer) for answer in vllm_aclgraph_ds_answers - ] if model == "vllm-ascend/DeepSeek-V2-Lite-W8A8" else [ - ([0], answer) for answer in vllm_aclgraph_qwen_answers - ]) + vllm_eager_outputs_list = [] + for output in vllm_eager_outputs: + vllm_eager_outputs_list.append( + (output.outputs[0].index, output.outputs[0].text)) check_outputs_equal( outputs_0_lst=vllm_eager_outputs_list, @@ -141,18 +134,12 @@ def test_output_between_eager_and_full_decode_only( 'and $x^2 + cx + b = 0$ also have a common real root.' 'Compute the sum $a + b + c$.') ] - if vllm_version_is("0.12.0"): - vllm_aclgraph_qwen_answers = [ - ' \n\nTo solve this problem, we need to use the Law of Sines and Law of Cosines. Let me start by drawing triangle $ABC$ with the', - ' \n\nTo solve this problem, we can use the following approach: Let $ABCD$ be a unit square with coordinates $A(0,0), B', - ' \n\nTo solve this problem, we can use the following approach: Let $ \\alpha $ be the common real root of the two equations. Then, we can' - ] - else: - vllm_aclgraph_qwen_answers = [ - ' \n\nTo solve this problem, we need to use the Law of Sines and Law of Cosines. Let me start by drawing triangle $ABC$ with the', - ' \n\nTo solve this problem, we can use the fact that the expected value of the area of a triangle formed by two random points on a square\'s perimeter is', - ' \n\nTo solve this problem, we can use the following approach: Let $ \\alpha $ be the common real root of the two equations. Then, we can' - ] + vllm_aclgraph_qwen_answers = [ + ' \n\nTo solve this problem, we need to use the Law of Sines and Law of Cosines. Let me start by drawing triangle $ABC$ with the', + " \n\nTo solve this problem, we can use the fact that the expected value of the area of a triangle formed by two random points on a square's perimeter is", + ' \n\nTo solve this problem, we can use the following approach: Let $ \\alpha $ be the common real root of the two equations. Then, we can' + ] + vllm_aclgraph_ds_answers = [ '\n\nSelect an assignment template', '\n\nSelect an assignment template', @@ -223,4 +210,4 @@ def test_aclgraph_enable(): # after check_and_update_config, mode should be VLLM_COMPILE and piecewise cudagraph NPUPlatform.check_and_update_config(VllmConfig) assert VllmConfig.compilation_config.mode == CompilationMode.VLLM_COMPILE - assert VllmConfig.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE + assert VllmConfig.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE \ No newline at end of file diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 29dcef8d54b..e9b9b1f6200 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -17,6 +17,7 @@ from uuid import uuid4 from vllm.logger import logger +from vllm.triton_utils import HAS_TRITON def check_kv_extra_config(vllm_config): @@ -233,7 +234,7 @@ class AscendCompilationConfig: def __init__(self, fuse_norm_quant: bool = True, - fuse_qknorm_rope: bool = True, + fuse_qknorm_rope: bool = False, **kwargs): """ Initialize the configuration. @@ -243,11 +244,11 @@ def __init__(self, When set to True, the system will optimize norm and quant operations. Default: True fuse_qknorm_rope (bool): Whether to enable qknorm and rope fusion optimization. - Default: True + Default: False **kwargs: Additional optional parameters for forward compatibility and configuration extension. """ self.fuse_norm_quant = fuse_norm_quant - self.fuse_qknorm_rope = fuse_qknorm_rope + self.fuse_qknorm_rope = HAS_TRITON or fuse_qknorm_rope class XliteGraphConfig: From 91435210d058743d282edddea68ac7aaa7a72440 Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Tue, 16 Dec 2025 10:09:11 +0000 Subject: [PATCH 38/40] resolve conflict Signed-off-by: wxsIcey <1790571317@qq.com> --- vllm_ascend/ascend_forward_context.py | 31 ----------------------- vllm_ascend/spec_decode/eagle_proposer.py | 3 +++ 2 files changed, 3 insertions(+), 31 deletions(-) diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index b4343e76f0c..b00bdabced6 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -209,37 +209,6 @@ def get_mc2_mask(): return _reserved_mc2_mask -def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype, - device): - global _cos - global _sin - if _cos is not None: - return - compilation_config = vllm_config.compilation_config - model_config = vllm_config.model_config - if model_config.use_mla and compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY: - rope_dim = model_config.hf_text_config.qk_rope_head_dim - _cos = torch.ones(max_num_reqs * decode_token_per_req, - 1, - 1, - rope_dim, - dtype=dtype, - device=device) - _sin = torch.zeros(max_num_reqs * decode_token_per_req, - 1, - 1, - rope_dim, - dtype=dtype, - device=device) - else: - _cos = None - _sin = None - - -def get_cos_and_sin(): - return _cos, _sin - - def select_moe_comm_method(num_tokens: int, vllm_config: VllmConfig) -> Optional[MoECommType]: """1. If expert parallel is not enabled, we use all-gather since MC2 and all-to-all diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index 74e846904c8..ebb623c75a6 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -144,6 +144,9 @@ def dummy_run(self, aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, batch_descriptor=None, dummy_compute_logits=lambda hidden_states: None): + # update global cos, sin + update_cos_sin(self.positions[:num_tokens]) + with set_ascend_forward_context(None, self.vllm_config, in_profile_run=True, From dd4617fe31e0dc4e481bbb043b82b3004d7d1321 Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Tue, 16 Dec 2025 14:27:11 +0000 Subject: [PATCH 39/40] fix Signed-off-by: wxsIcey <1790571317@qq.com> --- .../passes/qknorm_rope_fusion_pass.py | 23 +++++++++++++++---- vllm_ascend/worker/model_runner_v1.py | 10 ++++---- 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py b/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py index 7907a6bc0f2..d0f1aa53296 100644 --- a/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py +++ b/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py @@ -29,7 +29,13 @@ class QKNormRopeFusionPattern: - def __init__(self, head_dim, num_heads, num_kv_heads, eps=1e-6): + def __init__(self, + vllm_config, + head_dim, + num_heads, + num_kv_heads, + eps=1e-6): + self.vllm_config = vllm_config self.head_dim = head_dim self.num_heads = num_heads self.num_kv_heads = num_kv_heads @@ -121,14 +127,19 @@ def replacement(qkv: torch.Tensor, q_weight: torch.Tensor, class QKNormRopeFusionPatternWithBias: - def __init__(self, head_dim, num_heads, num_kv_heads, eps=1e-6): + def __init__(self, + vllm_config, + head_dim, + num_heads, + num_kv_heads, + eps=1e-6): self.head_dim = head_dim self.num_heads = num_heads self.num_kv_heads = num_kv_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.eps = eps - vllm_config = get_current_vllm_config() + self.vllm_config = vllm_config self.device = vllm_config.device_config.device if vllm_config.device_config else None def get_inputs(self): @@ -248,13 +259,15 @@ def __init__(self, vllm_config: VllmConfig): "QKNorm and Rope fusion not enabled: head_dim %d is not equal of 128", layer.head_size) continue - QKNormRopeFusionPattern(head_dim=layer.head_size, + QKNormRopeFusionPattern(vllm_config=vllm_config, + head_dim=layer.head_size, num_heads=layer.num_heads, num_kv_heads=layer.num_kv_heads, eps=epsilon).register( self.pattern_match_passes) - QKNormRopeFusionPatternWithBias(head_dim=layer.head_size, + QKNormRopeFusionPatternWithBias(vllm_config=vllm_config, + head_dim=layer.head_size, num_heads=layer.num_heads, num_kv_heads=layer.num_kv_heads, eps=epsilon).register( diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 3daf67dc76f..5a82c021b3b 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -84,12 +84,6 @@ import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.ascend_forward_context import (MoECommType, - get_mc2_tokens_capacity, - select_moe_comm_method, - set_ascend_forward_context, - set_mc2_mask, - set_mc2_tokens_capacity) from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, @@ -126,6 +120,10 @@ is_moe_model, lmhead_tp_enable, vllm_version_is) from vllm_ascend.worker.npu_input_batch import NPUInputBatch +from vllm_ascend.ascend_forward_context import ( # isort: skip + MoECommType, get_mc2_tokens_capacity, select_moe_comm_method, + set_ascend_forward_context, set_mc2_mask, set_mc2_tokens_capacity) + if TYPE_CHECKING: import xgrammar as xgr # type: ignore[import-untyped] from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput From d90be755d652dca1100f7f541b7fa68ef927d3b3 Mon Sep 17 00:00:00 2001 From: wxsIcey <1790571317@qq.com> Date: Tue, 16 Dec 2025 14:32:56 +0000 Subject: [PATCH 40/40] fix Signed-off-by: wxsIcey <1790571317@qq.com> --- vllm_ascend/spec_decode/eagle_proposer.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index ebb623c75a6..ff662b728ba 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -343,11 +343,9 @@ def _propose( builder = self.runner.attn_groups[0][0].get_metadata_builder() attn_metadata = builder.build(0, common_attn_metadata, self.runner.get_model()) - - positions = self.positions[:num_input_tokens] # update global cos, sin - update_cos_sin(positions) - + update_cos_sin(self.positions[:num_input_tokens]) + with set_ascend_forward_context(attn_metadata, self.vllm_config, num_tokens=num_input_tokens):