diff --git a/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py b/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py
index 8090ef19cb0..7c2230e82aa 100644
--- a/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py
+++ b/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py
@@ -25,6 +25,8 @@
 from vllm.config.compilation import Range
 from vllm.logger import logger
 
+from vllm_ascend.ops.rotary_embedding import get_rope_dim
+
 
 class QKNormRopeFusionPattern:
 
@@ -252,12 +254,16 @@ def __init__(self, vllm_config: VllmConfig):
             )
             return
         layer = next(iter(attn_layers.values()))
+        rope_dim = get_rope_dim(vllm_config)
+        if layer.head_size != 128 or rope_dim != layer.head_size:
+            logger.debug(
+                "Currently, QKNorm and Rope fusion is only supported where "
+                "rotary_dim == head_size and head_size == 128. But rotary_dim "
+                f"is {rope_dim} and head_size is {layer.head_size}. Therefore, "
+                "the fusion is skipped.")
+            return
+
         for epsilon in [1e-6, 1e-5]:
-            if layer.head_size != 128:
-                logger.debug(
-                    "QKNorm and Rope fusion not enabled: head_dim %d is not equal of 128",
-                    layer.head_size)
-                continue
             QKNormRopeFusionPattern(vllm_config=vllm_config,
                                     head_dim=layer.head_size,
                                     num_heads=layer.num_heads,
diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py
index 79bfa0a2549..b9b94e12b8d 100644
--- a/vllm_ascend/ops/rotary_embedding.py
+++ b/vllm_ascend/ops/rotary_embedding.py
@@ -36,6 +36,24 @@
 from vllm_ascend.ops.triton.rope import rope_forward_triton
 
 
+
+def get_rope_dim(vllm_config):
+    model_config = vllm_config.model_config
+
+    if model_config.use_mla:
+        rope_dim = model_config.hf_text_config.qk_rope_head_dim
+    else:
+        rope_dim = model_config.get_head_size()
+        # For models using partial rope like Qwen3-Next.
+        if hasattr(model_config.hf_text_config, "partial_rotary_factor"):
+            rope_dim = int(rope_dim *
+                           model_config.hf_text_config.partial_rotary_factor)
+        elif hasattr(model_config.hf_text_config, "rotary_dim"):
+            rope_dim = int(model_config.hf_text_config.rotary_dim)
+
+    return rope_dim
+
+
 # Currently, rope ops used on npu requires detached cos && sin as inputs.
 # However, RotaryEmbedding in vllm use cos_sin_cache as a whole variable.
 # So we have to preprocess cos_sin_cache int cos && sin. In the future,
@@ -73,7 +91,7 @@ def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype,
     max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
 
     if model_config.use_mla:
-        rope_dim = model_config.hf_text_config.qk_rope_head_dim
+        rope_dim = get_rope_dim(vllm_config)
         _cos_mla = torch.ones(max_num_batched_tokens,
                               1,
                               1,
@@ -87,11 +105,7 @@ def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype,
                               dtype=dtype,
                               device=device)
     elif not is_vl_model(vllm_config) and has_rope(vllm_config):
-        rope_dim = model_config.get_head_size()
-        # For models using partial rope like Qwen3-Next.
-        if hasattr(model_config.hf_text_config, "partial_rotary_factor"):
-            rope_dim = int(rope_dim *
-                           model_config.hf_text_config.partial_rotary_factor)
+        rope_dim = get_rope_dim(vllm_config)
         _cos = torch.ones(1,
                           max_num_batched_tokens,
                           1,
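
Note on the new helper (illustrative, not part of the patch): get_rope_dim centralizes how the rotary dimension is derived, so the fusion pass can skip models whose rope covers only part of each head (partial rope) or whose head size is not 128. Below is a minimal runnable sketch of the non-MLA branch, using SimpleNamespace stand-ins for the HF text config; the attribute values are hypothetical, chosen only to exercise both paths.

    from types import SimpleNamespace

    def rope_dim(head_size, hf_text_config):
        # Mirrors the non-MLA branch of get_rope_dim in this patch.
        dim = head_size
        if hasattr(hf_text_config, "partial_rotary_factor"):
            # Partial rope: only a fraction of each head is rotated.
            dim = int(dim * hf_text_config.partial_rotary_factor)
        elif hasattr(hf_text_config, "rotary_dim"):
            # Some configs state the rotary dimension explicitly.
            dim = int(hf_text_config.rotary_dim)
        return dim

    # Partial rope (Qwen3-Next-like, hypothetical numbers): rotary_dim != head_size,
    # so the fusion pass above would bail out early.
    assert rope_dim(256, SimpleNamespace(partial_rotary_factor=0.25)) == 64
    # Full rope with head_size 128: the fusion precondition holds.
    assert rope_dim(128, SimpleNamespace()) == 128

With this helper, set_cos_and_sin and the fusion pass share one definition of the rotary dimension instead of each re-deriving it from the HF config.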