diff --git a/.github/workflows/_e2e_test.yaml b/.github/workflows/_e2e_test.yaml index 97ccc4a6c0f..30d7eb4c606 100644 --- a/.github/workflows/_e2e_test.yaml +++ b/.github/workflows/_e2e_test.yaml @@ -103,6 +103,7 @@ jobs: # We found that if running aclgraph tests in batch, it will cause AclmdlRICaptureBegin error. So we run # the test separately. + pytest -sv --durations=0 tests/e2e/singlecard/test_split_qkv_rmsnorm_rope.py pytest -sv --durations=0 tests/e2e/singlecard/test_completion_with_prompt_embeds.py pytest -sv --durations=0 tests/e2e/singlecard/test_aclgraph_accuracy.py pytest -sv --durations=0 tests/e2e/singlecard/test_async_scheduling.py diff --git a/tests/e2e/singlecard/test_split_qkv_rmsnorm_rope.py b/tests/e2e/singlecard/test_split_qkv_rmsnorm_rope.py new file mode 100644 index 00000000000..9a3b2524b12 --- /dev/null +++ b/tests/e2e/singlecard/test_split_qkv_rmsnorm_rope.py @@ -0,0 +1,214 @@ +import gc + +import numpy as np +import pytest +import torch + +import vllm_ascend.ops.register_custom_ops # noqa +from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton + +NUM_TOKENS = [1, 4, 8, 16, 1024] +NUM_QKV_HEADS = [(12, 1), (16, 1), (32, 4), (64, 4)] +HEAD_SIZES = [128] +EPS = [1e-6] +DTYPES = [torch.bfloat16] +SEEDS = [0] +DEVICES = [f"npu:{0}"] +DEFAULT_ATOL = 5e-2 +DEFAULT_RTOL = 5e-3 + + +def custom_rope(q, k, sin, cos): + rotary_dim = sin.shape[-1] + sin = sin.to(torch.float32) + cos = cos.to(torch.float32) + x1 = q[..., :rotary_dim // 2] + x2 = q[..., rotary_dim // 2:] + cat_x = torch.cat([-x2, x1], axis=-1) + mul1 = cat_x * sin + mul2 = q * cos + res1 = mul1 + mul2 + + x1 = k[..., :rotary_dim // 2] + x2 = k[..., rotary_dim // 2:] + cat_x = torch.cat([-x2, x1], axis=-1) + mul1 = cat_x * sin + mul2 = k * cos + res2 = mul1 + mul2 + return res1, res2 + + +def rms_norm( + input, + norm_weight, + eps, + norm_bias=None, +): + input = input.to(torch.float32) + norm_weight = norm_weight.to(torch.float32) + reciprocal_std = 1 / torch.sqrt( + torch.mean(input**2, axis=-1, keepdims=True) + eps) + out = input * reciprocal_std * norm_weight + if norm_bias is not None: + norm_bias = norm_bias.to(torch.float32) + out = out + norm_bias + return out + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("num_q_heads, num_kv_heads", NUM_QKV_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("eps", EPS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", DEVICES) +@torch.inference_mode() +def test_split_qkv_rmsnorm_rope(num_tokens, num_q_heads, num_kv_heads, + head_size, eps, dtype, seed, device): + torch.manual_seed(seed) + torch.set_default_device(device) + init_device_properties_triton() + + q_hidden_size = num_q_heads * head_size + kv_hidden_size = num_kv_heads * head_size + qkv = torch.randn(num_tokens, + q_hidden_size + kv_hidden_size * 2, + dtype=dtype, + device=device) + q_weight = torch.randn(head_size, dtype=dtype, device=device) + k_weight = torch.randn(head_size, dtype=dtype, device=device) + sin = torch.from_numpy( + np.random.uniform(0, 1, + [num_tokens, 1, 1, head_size])).to(dtype).npu() + cos = torch.from_numpy( + np.random.uniform(0, 1, + [num_tokens, 1, 1, head_size])).to(dtype).npu() + # fused kernel + q, k, v = torch.ops.vllm.qkv_rmsnorm_rope(input=qkv, + q_weight=q_weight, + k_weight=k_weight, + q_hidden_size=q_hidden_size, + kv_hidden_size=kv_hidden_size, + head_dim=head_size, + eps=eps, + cos=cos, + sin=sin) + + # split + _q, _k, v_gold = qkv.cpu().split( + [q_hidden_size, kv_hidden_size, kv_hidden_size], dim=-1) + # norm + _q = rms_norm(_q.reshape(-1, head_size), q_weight.cpu(), eps) + _k = rms_norm(_k.reshape(-1, head_size), k_weight.cpu(), eps) + _q = _q.reshape(num_tokens, 1, -1, head_size) + _k = _k.reshape(num_tokens, 1, -1, head_size) + + # rope + q_gold, k_gold = custom_rope(_q, _k, sin.cpu(), cos.cpu()) + q_gold = q_gold.reshape(num_tokens, -1) + k_gold = k_gold.reshape(num_tokens, -1) + + # Compare the results. + torch.testing.assert_close(q.to(torch.float32).cpu(), + q_gold, + atol=DEFAULT_ATOL, + rtol=DEFAULT_RTOL) + + torch.testing.assert_close(k.to(torch.float32).cpu(), + k_gold, + atol=DEFAULT_ATOL, + rtol=DEFAULT_RTOL) + + torch.testing.assert_close(v.to(torch.float32).cpu(), + v_gold.to(torch.float32), + atol=DEFAULT_ATOL, + rtol=DEFAULT_RTOL) + + gc.collect() + torch.npu.empty_cache() + torch.npu.reset_peak_memory_stats() + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("num_q_heads, num_kv_heads", NUM_QKV_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("eps", EPS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", DEVICES) +@torch.inference_mode() +def test_split_qkv_rmsnorm_rope_with_bias(num_tokens, num_q_heads, + num_kv_heads, head_size, eps, dtype, + seed, device): + torch.manual_seed(seed) + torch.set_default_device(device) + init_device_properties_triton() + + q_hidden_size = num_q_heads * head_size + kv_hidden_size = num_kv_heads * head_size + qkv = torch.randn(num_tokens, + q_hidden_size + kv_hidden_size * 2, + dtype=dtype, + device=device) + q_weight = torch.randn(head_size, dtype=dtype, device=device) + k_weight = torch.randn(head_size, dtype=dtype, device=device) + q_bias = torch.randn(head_size, dtype=dtype, device=device) + k_bias = torch.randn(head_size, dtype=dtype, device=device) + sin = torch.from_numpy( + np.random.uniform(0, 1, + [num_tokens, 1, 1, head_size])).to(dtype).npu() + cos = torch.from_numpy( + np.random.uniform(0, 1, + [num_tokens, 1, 1, head_size])).to(dtype).npu() + # fused kernel + q, k, v = torch.ops.vllm.qkv_rmsnorm_rope(input=qkv, + q_weight=q_weight, + k_weight=k_weight, + q_hidden_size=q_hidden_size, + kv_hidden_size=kv_hidden_size, + head_dim=head_size, + eps=eps, + q_bias=q_bias, + k_bias=k_bias, + cos=cos, + sin=sin) + + # split + _q, _k, v_gold = qkv.cpu().split( + [q_hidden_size, kv_hidden_size, kv_hidden_size], dim=-1) + # norm + _q = rms_norm(_q.reshape(-1, head_size), + q_weight.cpu(), + eps, + norm_bias=q_bias.cpu()) + _k = rms_norm(_k.reshape(-1, head_size), + k_weight.cpu(), + eps, + norm_bias=k_bias.cpu()) + _q = _q.reshape(num_tokens, 1, -1, head_size) + _k = _k.reshape(num_tokens, 1, -1, head_size) + + # rope + q_gold, k_gold = custom_rope(_q, _k, sin.cpu(), cos.cpu()) + q_gold = q_gold.reshape(num_tokens, -1) + k_gold = k_gold.reshape(num_tokens, -1) + + # Compare the results. + torch.testing.assert_close(q.to(torch.float32).cpu(), + q_gold, + atol=DEFAULT_ATOL, + rtol=DEFAULT_RTOL) + + torch.testing.assert_close(k.to(torch.float32).cpu(), + k_gold, + atol=DEFAULT_ATOL, + rtol=DEFAULT_RTOL) + + torch.testing.assert_close(v.to(torch.float32).cpu(), + v_gold.to(torch.float32), + atol=DEFAULT_ATOL, + rtol=DEFAULT_RTOL) + + gc.collect() + torch.npu.empty_cache() + torch.npu.reset_peak_memory_stats() diff --git a/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py b/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py index 58e56ae5ee5..6bc2d373249 100644 --- a/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py +++ b/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py @@ -209,8 +209,8 @@ def split_qkv_rmsnorm_rope_impl( kv_hidden_size: int, head_dim: int, eps: float, - q_bias: Optional[torch.Tensor], - k_bias: Optional[torch.Tensor], + q_bias: Optional[torch.Tensor] = None, + k_bias: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: KV_BLOCK_SIZE = triton.next_power_of_2(head_dim) assert KV_BLOCK_SIZE == head_dim