Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from vllm.config.compilation import Range
from vllm.logger import logger

from vllm_ascend.ops.rotary_embedding import get_rope_dim


class QKNormRopeFusionPattern:

Expand Down Expand Up @@ -252,12 +254,16 @@ def __init__(self, vllm_config: VllmConfig):
)
return
layer = next(iter(attn_layers.values()))
rope_dim = get_rope_dim(vllm_config)
if layer.head_size != 128 or rope_dim != layer.head_size:
logger.debug(
f"Currently, QKNorm and Rope fusion is only supported where"
f"rotary_dim == head_size and head_size == 128. But rotary_dim"
f"is {rope_dim} and head_size is {layer.head_size}. Therefore"
f"the fusion is skipped.")
return

for epsilon in [1e-6, 1e-5]:
if layer.head_size != 128:
logger.debug(
"QKNorm and Rope fusion not enabled: head_dim %d is not equal of 128",
layer.head_size)
continue
QKNormRopeFusionPattern(vllm_config=vllm_config,
head_dim=layer.head_size,
num_heads=layer.num_heads,
Expand Down
26 changes: 20 additions & 6 deletions vllm_ascend/ops/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,24 @@

from vllm_ascend.ops.triton.rope import rope_forward_triton


def get_rope_dim(vllm_config):
    """Return the rotary-embedding dimension for the configured model.

    MLA models carry a dedicated rope head dim (``qk_rope_head_dim``).
    Otherwise the head size is used, scaled down for partial-rope models
    (``partial_rotary_factor``, e.g. Qwen3-Next) or overridden by an
    explicit ``rotary_dim`` on the HF text config.
    """
    model_config = vllm_config.model_config
    hf_cfg = model_config.hf_text_config

    if model_config.use_mla:
        return hf_cfg.qk_rope_head_dim

    head_size = model_config.get_head_size()
    # Sentinel distinguishes "attribute absent" from "attribute is None",
    # matching the original hasattr-based priority order exactly.
    _absent = object()
    partial_factor = getattr(hf_cfg, "partial_rotary_factor", _absent)
    if partial_factor is not _absent:
        # For models using partial rope like Qwen3-Next.
        return int(head_size * partial_factor)
    rotary_dim = getattr(hf_cfg, "rotary_dim", _absent)
    if rotary_dim is not _absent:
        return int(rotary_dim)
    return head_size


# Currently, rope ops used on npu requires detached cos && sin as inputs.
# However, RotaryEmbedding in vllm use cos_sin_cache as a whole variable.
# So we have to preprocess cos_sin_cache int cos && sin. In the future,
Expand Down Expand Up @@ -73,7 +91,7 @@ def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype,
max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens

if model_config.use_mla:
rope_dim = model_config.hf_text_config.qk_rope_head_dim
rope_dim = get_rope_dim(vllm_config)
_cos_mla = torch.ones(max_num_batched_tokens,
1,
1,
Expand All @@ -87,11 +105,7 @@ def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype,
dtype=dtype,
device=device)
elif not is_vl_model(vllm_config) and has_rope(vllm_config):
rope_dim = model_config.get_head_size()
# For models using partial rope like Qwen3-Next.
if hasattr(model_config.hf_text_config, "partial_rotary_factor"):
rope_dim = int(rope_dim *
model_config.hf_text_config.partial_rotary_factor)
rope_dim = get_rope_dim(vllm_config)
_cos = torch.ones(1,
max_num_batched_tokens,
1,
Expand Down