Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
NUM_TOKENS = [1, 4, 8, 16, 1024]
NUM_QKV_HEADS = [(12, 1), (16, 1), (32, 4), (64, 4)]
HEAD_SIZES = [128]
ROPE_DIMS = [64, 128]
EPS = [1e-6]
DTYPES = [torch.bfloat16]
SEEDS = [0]
Expand All @@ -23,19 +24,26 @@ def custom_rope(q, k, sin, cos):
rotary_dim = sin.shape[-1]
sin = sin.to(torch.float32)
cos = cos.to(torch.float32)
x1 = q[..., :rotary_dim // 2]
x2 = q[..., rotary_dim // 2:]
q_rot = q[..., :rotary_dim]
k_rot = k[..., :rotary_dim]
q_pass = q[..., rotary_dim:]
k_pass = k[..., rotary_dim:]

x1 = q_rot[..., :rotary_dim // 2]
x2 = q_rot[..., rotary_dim // 2:]
cat_x = torch.cat([-x2, x1], axis=-1)
mul1 = cat_x * sin
mul2 = q * cos
res1 = mul1 + mul2
mul2 = q_rot * cos
q_rot = mul1 + mul2
res1 = torch.cat([q_rot, q_pass], dim=-1)

x1 = k[..., :rotary_dim // 2]
x2 = k[..., rotary_dim // 2:]
x1 = k_rot[..., :rotary_dim // 2]
x2 = k_rot[..., rotary_dim // 2:]
cat_x = torch.cat([-x2, x1], axis=-1)
mul1 = cat_x * sin
mul2 = k * cos
res2 = mul1 + mul2
mul2 = k_rot * cos
k_rot = mul1 + mul2
res2 = torch.cat([k_rot, k_pass], dim=-1)
return res1, res2


Expand Down Expand Up @@ -64,9 +72,10 @@ def rms_norm(
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("rope_dim", ROPE_DIMS)
@torch.inference_mode()
def test_split_qkv_rmsnorm_rope(max_position_embeddings, num_tokens, num_q_heads, num_kv_heads,
head_size, eps, dtype, seed, device):
head_size, eps, dtype, seed, device, rope_dim):
torch.manual_seed(seed)
torch.set_default_device(device)
init_device_properties_triton()
Expand All @@ -81,7 +90,7 @@ def test_split_qkv_rmsnorm_rope(max_position_embeddings, num_tokens, num_q_heads
k_weight = torch.randn(head_size, dtype=dtype, device=device)
cos_sin_cache = torch.from_numpy(
np.random.uniform(0, 1,
[max_position_embeddings, head_size])).to(dtype).npu()
[max_position_embeddings, rope_dim])).to(dtype).npu()
positions = torch.randint(low=0, high=max_position_embeddings, size=(num_tokens,), dtype=torch.int64, device=device)
# fused kernel
q, k, v = torch.ops.vllm.qkv_rmsnorm_rope(input=qkv,
Expand Down Expand Up @@ -141,10 +150,11 @@ def test_split_qkv_rmsnorm_rope(max_position_embeddings, num_tokens, num_q_heads
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("rope_dim", ROPE_DIMS)
@torch.inference_mode()
def test_split_qkv_rmsnorm_rope_with_bias(max_position_embeddings, num_tokens, num_q_heads,
num_kv_heads, head_size, eps, dtype,
seed, device):
seed, device, rope_dim):
torch.manual_seed(seed)
torch.set_default_device(device)
init_device_properties_triton()
Expand All @@ -161,7 +171,7 @@ def test_split_qkv_rmsnorm_rope_with_bias(max_position_embeddings, num_tokens, n
k_bias = torch.randn(head_size, dtype=dtype, device=device)
cos_sin_cache = torch.from_numpy(
np.random.uniform(0, 1,
[max_position_embeddings, head_size])).to(dtype).npu()
[max_position_embeddings, rope_dim])).to(dtype).npu()
positions = torch.randint(low=0, high=max_position_embeddings, size=(num_tokens,), dtype=torch.int64, device=device)
# fused kernel
q, k, v = torch.ops.vllm.qkv_rmsnorm_rope(input=qkv,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy

import numpy as np
import pytest
import torch
import torch.nn as nn
Expand All @@ -16,6 +17,8 @@
)
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton

MAX_POSITION_EMBEDDING = 262144


def find_op(gm, op_default):
return any(node.op == "call_function" and node.target == op_default for node in gm.graph.nodes)
Expand Down Expand Up @@ -207,8 +210,10 @@ def test_rmsnorm_quant_fusion(
model = model.to("npu")
seq_len = 5
qkv = torch.randn(seq_len, qkv_size, device="npu", dtype=dtype)
cos = torch.randn(1, seq_len, 1, head_dim, device="npu", dtype=dtype)
sin = torch.randn(1, seq_len, 1, head_dim, device="npu", dtype=dtype)
cos_sin_cache = torch.from_numpy(np.random.uniform(0, 1, [MAX_POSITION_EMBEDDING, head_dim])).to(dtype).npu()
positions = torch.randint(
low=0, high=MAX_POSITION_EMBEDDING, size=(num_tokens,), dtype=torch.int64, device="npu"
)

with torch.no_grad():
original_optimize = torchair.npu_fx_compiler._optimize_fx
Expand All @@ -218,6 +223,6 @@ def test_rmsnorm_quant_fusion(

compiled_model = torch.compile(model, backend="npugraph_ex", fullgraph=True, dynamic=True)

compiled_model(qkv, cos, sin)
compiled_model(qkv, cos_sin_cache, positions)

torchair.npu_fx_compiler._optimize_fx = original_optimize
4 changes: 2 additions & 2 deletions tests/e2e/singlecard/test_aclgraph_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
prompts=PROMPTS_LONG,
golden_answers=[
" \n\nTo solve this problem, we need to use the Law of Sines and Law of Cosines. Let me start by drawing triangle $ABC$ with the",
" \n\nTo solve this problem, we can use the fact that the expected value of the area of a triangle with vertices on a square can be calculated by integrating over",
" \n\nTo solve this problem, we can use the following approach: Let $P$ be the perimeter of the square. Then, the expected value of the area",
" \n\nTo solve this problem, we can use the following approach: Let $ \\alpha $ be the common real root of the two equations. Then, we can",
],
)
Expand All @@ -95,7 +95,7 @@
prompts=PROMPTS_LONG,
golden_answers=[
" \n\nTo solve this problem, we need to use the Law of Sines and Law of Cosines. Let me start by drawing triangle $ABC$ with the",
" \n\nTo solve this problem, we can use the fact that the expected value of the area of a triangle with vertices on a square can be calculated by integrating over",
" \n\nTo solve this problem, we can use the following approach: Let $P$ be the perimeter of the square. Then, the expected value of the area",
" \n\nTo solve this problem, we can use the following approach: Let $ \\alpha $ be the common real root of the two equations. Then, we can",
],
)
Expand Down
Loading
Loading