Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/_e2e_test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ jobs:
# We found that if running aclgraph tests in batch, it will cause AclmdlRICaptureBegin error. So we run
# the test separately.

pytest -sv --durations=0 tests/e2e/singlecard/test_split_qkv_rmsnorm_rope.py
pytest -sv --durations=0 tests/e2e/singlecard/test_completion_with_prompt_embeds.py
pytest -sv --durations=0 tests/e2e/singlecard/test_aclgraph_accuracy.py
pytest -sv --durations=0 tests/e2e/singlecard/test_async_scheduling.py
Expand Down
214 changes: 214 additions & 0 deletions tests/e2e/singlecard/test_split_qkv_rmsnorm_rope.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
import gc

import numpy as np
import pytest
import torch

import vllm_ascend.ops.register_custom_ops # noqa
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton

# Test-matrix parameters for the fused split-QKV + RMSNorm + RoPE kernel.
NUM_TOKENS = [1, 4, 8, 16, 1024]  # token counts, incl. a large batch
NUM_QKV_HEADS = [(12, 1), (16, 1), (32, 4), (64, 4)]  # (q heads, kv heads)
HEAD_SIZES = [128]  # per-head hidden size
EPS = [1e-6]  # RMSNorm epsilon
DTYPES = [torch.bfloat16]
SEEDS = [0]
DEVICES = [f"npu:{0}"]  # single-card NPU
# Tolerances for fused-vs-reference comparison (loose to absorb bf16 error).
DEFAULT_ATOL = 5e-2
DEFAULT_RTOL = 5e-3


def custom_rope(q, k, sin, cos):
    """Reference rotary position embedding (RoPE) applied to q and k.

    Computes ``rotate_half(x) * sin + x * cos`` for both inputs, with
    sin/cos upcast to float32 before the math.  The same transform was
    previously written out twice (once for q, once for k); it is now a
    single helper applied to each tensor.

    NOTE(review): ``x * cos`` multiplies the *full* last dimension, so this
    assumes rotary_dim == head_size (sin/cos span the whole head) — the
    test always builds sin/cos with head_size channels.

    Returns:
        (rotated_q, rotated_k) as float32 tensors.
    """
    rotary_dim = sin.shape[-1]
    sin = sin.to(torch.float32)
    cos = cos.to(torch.float32)

    def _rotate(x):
        # rotate_half: swap the two halves of the rotary dims, negating
        # the second half: [x1, x2] -> [-x2, x1].
        x1 = x[..., :rotary_dim // 2]
        x2 = x[..., rotary_dim // 2:]
        rotated = torch.cat([-x2, x1], dim=-1)
        return rotated * sin + x * cos

    return _rotate(q), _rotate(k)


def rms_norm(
    input,
    norm_weight,
    eps,
    norm_bias=None,
):
    """Reference RMSNorm over the last dimension, computed in float32.

    Args:
        input: tensor to normalize; reduced over its last dimension.
        norm_weight: per-channel scale, broadcast over the last dimension.
        eps: small constant added to the mean square for stability.
        norm_bias: optional per-channel shift applied after scaling.

    Returns:
        float32 tensor with the same shape as ``input``.
    """
    input = input.to(torch.float32)
    norm_weight = norm_weight.to(torch.float32)
    # Use torch-idiomatic dim/keepdim (not the numpy-compat axis/keepdims
    # aliases) and rsqrt instead of 1/sqrt.
    inv_rms = torch.rsqrt(torch.mean(input**2, dim=-1, keepdim=True) + eps)
    out = input * inv_rms * norm_weight
    if norm_bias is not None:
        out = out + norm_bias.to(torch.float32)
    return out


@pytest.mark.parametrize("with_bias", [False, True])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_q_heads, num_kv_heads", NUM_QKV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("eps", EPS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_split_qkv_rmsnorm_rope(with_bias, num_tokens, num_q_heads,
                                num_kv_heads, head_size, eps, dtype, seed,
                                device):
    """Check the fused qkv_rmsnorm_rope op against a pure-PyTorch reference.

    The reference path splits qkv, applies rms_norm (optionally with a
    bias) to q and k per head, then applies custom_rope; the fused
    kernel's q/k/v must match within bf16-friendly tolerances.  The
    bias-free and biased variants are folded into one test via the
    ``with_bias`` flag instead of two near-identical test functions.
    """
    torch.manual_seed(seed)
    torch.set_default_device(device)
    init_device_properties_triton()

    q_hidden_size = num_q_heads * head_size
    kv_hidden_size = num_kv_heads * head_size
    qkv = torch.randn(num_tokens,
                      q_hidden_size + kv_hidden_size * 2,
                      dtype=dtype,
                      device=device)
    q_weight = torch.randn(head_size, dtype=dtype, device=device)
    k_weight = torch.randn(head_size, dtype=dtype, device=device)
    sin = torch.from_numpy(
        np.random.uniform(0, 1,
                          [num_tokens, 1, 1, head_size])).to(dtype).npu()
    cos = torch.from_numpy(
        np.random.uniform(0, 1,
                          [num_tokens, 1, 1, head_size])).to(dtype).npu()

    # Biases are optional: when absent both the fused op (which defaults
    # q_bias/k_bias to None) and the reference run the bias-free path.
    q_bias, k_bias = None, None
    norm_q_bias, norm_k_bias = None, None
    if with_bias:
        q_bias = torch.randn(head_size, dtype=dtype, device=device)
        k_bias = torch.randn(head_size, dtype=dtype, device=device)
        norm_q_bias = q_bias.cpu()
        norm_k_bias = k_bias.cpu()

    # fused kernel
    q, k, v = torch.ops.vllm.qkv_rmsnorm_rope(input=qkv,
                                              q_weight=q_weight,
                                              k_weight=k_weight,
                                              q_hidden_size=q_hidden_size,
                                              kv_hidden_size=kv_hidden_size,
                                              head_dim=head_size,
                                              eps=eps,
                                              q_bias=q_bias,
                                              k_bias=k_bias,
                                              cos=cos,
                                              sin=sin)

    # Reference: split qkv on CPU; v passes through the op untouched.
    _q, _k, v_gold = qkv.cpu().split(
        [q_hidden_size, kv_hidden_size, kv_hidden_size], dim=-1)
    # Per-head RMSNorm in float32.
    _q = rms_norm(_q.reshape(-1, head_size),
                  q_weight.cpu(),
                  eps,
                  norm_bias=norm_q_bias)
    _k = rms_norm(_k.reshape(-1, head_size),
                  k_weight.cpu(),
                  eps,
                  norm_bias=norm_k_bias)
    _q = _q.reshape(num_tokens, 1, -1, head_size)
    _k = _k.reshape(num_tokens, 1, -1, head_size)

    # rope
    q_gold, k_gold = custom_rope(_q, _k, sin.cpu(), cos.cpu())
    q_gold = q_gold.reshape(num_tokens, -1)
    k_gold = k_gold.reshape(num_tokens, -1)

    # Compare fused outputs against the reference in float32.
    torch.testing.assert_close(q.to(torch.float32).cpu(),
                               q_gold,
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)

    torch.testing.assert_close(k.to(torch.float32).cpu(),
                               k_gold,
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)

    torch.testing.assert_close(v.to(torch.float32).cpu(),
                               v_gold.to(torch.float32),
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)

    # Free NPU memory so repeated parametrizations don't accumulate.
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
Comment on lines +58 to +214
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The two test functions test_split_qkv_rmsnorm_rope and test_split_qkv_rmsnorm_rope_with_bias are largely identical, leading to significant code duplication. This makes the tests harder to read and maintain, as any future changes would need to be applied in two places.

To improve maintainability, these can be consolidated into a single test function parameterized by a with_bias boolean flag. This will remove over 80 lines of duplicated code and make the test logic clearer.

@pytest.mark.parametrize("with_bias", [False, True])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_q_heads, num_kv_heads", NUM_QKV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("eps", EPS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_split_qkv_rmsnorm_rope(with_bias, num_tokens, num_q_heads,
                                num_kv_heads, head_size, eps, dtype, seed,
                                device):
    torch.manual_seed(seed)
    torch.set_default_device(device)
    init_device_properties_triton()

    q_hidden_size = num_q_heads * head_size
    kv_hidden_size = num_kv_heads * head_size
    qkv = torch.randn(num_tokens,
                      q_hidden_size + kv_hidden_size * 2,
                      dtype=dtype,
                      device=device)
    q_weight = torch.randn(head_size, dtype=dtype, device=device)
    k_weight = torch.randn(head_size, dtype=dtype, device=device)
    sin = torch.from_numpy(
        np.random.uniform(0, 1,
                          [num_tokens, 1, 1, head_size])).to(dtype).npu()
    cos = torch.from_numpy(
        np.random.uniform(0, 1,
                          [num_tokens, 1, 1, head_size])).to(dtype).npu()

    q_bias, k_bias = None, None
    norm_q_bias, norm_k_bias = None, None
    if with_bias:
        q_bias = torch.randn(head_size, dtype=dtype, device=device)
        k_bias = torch.randn(head_size, dtype=dtype, device=device)
        norm_q_bias = q_bias.cpu()
        norm_k_bias = k_bias.cpu()

    # fused kernel
    q, k, v = torch.ops.vllm.qkv_rmsnorm_rope(input=qkv,
                                              q_weight=q_weight,
                                              k_weight=k_weight,
                                              q_hidden_size=q_hidden_size,
                                              kv_hidden_size=kv_hidden_size,
                                              head_dim=head_size,
                                              eps=eps,
                                              q_bias=q_bias,
                                              k_bias=k_bias,
                                              cos=cos,
                                              sin=sin)

    # split
    _q, _k, v_gold = qkv.cpu().split(
        [q_hidden_size, kv_hidden_size, kv_hidden_size], dim=-1)
    # norm
    _q = rms_norm(_q.reshape(-1, head_size),
                  q_weight.cpu(),
                  eps,
                  norm_bias=norm_q_bias)
    _k = rms_norm(_k.reshape(-1, head_size),
                  k_weight.cpu(),
                  eps,
                  norm_bias=norm_k_bias)
    _q = _q.reshape(num_tokens, 1, -1, head_size)
    _k = _k.reshape(num_tokens, 1, -1, head_size)

    # rope
    q_gold, k_gold = custom_rope(_q, _k, sin.cpu(), cos.cpu())
    q_gold = q_gold.reshape(num_tokens, -1)
    k_gold = k_gold.reshape(num_tokens, -1)

    # Compare the results.
    torch.testing.assert_close(q.to(torch.float32).cpu(),
                               q_gold,
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)

    torch.testing.assert_close(k.to(torch.float32).cpu(),
                               k_gold,
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)

    torch.testing.assert_close(v.to(torch.float32).cpu(),
                               v_gold.to(torch.float32),
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)

    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()

4 changes: 2 additions & 2 deletions vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,8 @@ def split_qkv_rmsnorm_rope_impl(
kv_hidden_size: int,
head_dim: int,
eps: float,
q_bias: Optional[torch.Tensor],
k_bias: Optional[torch.Tensor],
q_bias: Optional[torch.Tensor] = None,
k_bias: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
KV_BLOCK_SIZE = triton.next_power_of_2(head_dim)
assert KV_BLOCK_SIZE == head_dim
Expand Down
Loading