
[TEST] test e2e for split_qkv_rmsnorm_rope #5320

Closed
Angazenn wants to merge 1 commit into vllm-project:main from Angazenn:triton_e2e

Conversation

@Angazenn (Collaborator) commented Dec 24, 2025

What this PR does / why we need it?

Does this PR introduce any user-facing change?

How was this patch tested?

Signed-off-by: Angazenn <supperccell@163.com>
@gemini-code-assist (Bot, Contributor) left a comment

Code Review

This pull request introduces end-to-end tests for the split_qkv_rmsnorm_rope custom operator, covering cases with and without bias. The change to make bias arguments optional in the operator implementation is correct. However, the new test file contains two nearly identical test functions, one for the case with bias and one without. This significant code duplication can be avoided by merging them into a single, parameterized test function, which would improve maintainability. I've provided a suggestion to refactor the tests accordingly.

Comment on lines +58 to +214
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_q_heads, num_kv_heads", NUM_QKV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("eps", EPS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_split_qkv_rmsnorm_rope(num_tokens, num_q_heads, num_kv_heads,
                                head_size, eps, dtype, seed, device):
    torch.manual_seed(seed)
    torch.set_default_device(device)
    init_device_properties_triton()

    q_hidden_size = num_q_heads * head_size
    kv_hidden_size = num_kv_heads * head_size
    qkv = torch.randn(num_tokens,
                      q_hidden_size + kv_hidden_size * 2,
                      dtype=dtype,
                      device=device)
    q_weight = torch.randn(head_size, dtype=dtype, device=device)
    k_weight = torch.randn(head_size, dtype=dtype, device=device)
    sin = torch.from_numpy(
        np.random.uniform(0, 1,
                          [num_tokens, 1, 1, head_size])).to(dtype).npu()
    cos = torch.from_numpy(
        np.random.uniform(0, 1,
                          [num_tokens, 1, 1, head_size])).to(dtype).npu()
    # fused kernel
    q, k, v = torch.ops.vllm.qkv_rmsnorm_rope(input=qkv,
                                              q_weight=q_weight,
                                              k_weight=k_weight,
                                              q_hidden_size=q_hidden_size,
                                              kv_hidden_size=kv_hidden_size,
                                              head_dim=head_size,
                                              eps=eps,
                                              cos=cos,
                                              sin=sin)

    # split
    _q, _k, v_gold = qkv.cpu().split(
        [q_hidden_size, kv_hidden_size, kv_hidden_size], dim=-1)
    # norm
    _q = rms_norm(_q.reshape(-1, head_size), q_weight.cpu(), eps)
    _k = rms_norm(_k.reshape(-1, head_size), k_weight.cpu(), eps)
    _q = _q.reshape(num_tokens, 1, -1, head_size)
    _k = _k.reshape(num_tokens, 1, -1, head_size)

    # rope
    q_gold, k_gold = custom_rope(_q, _k, sin.cpu(), cos.cpu())
    q_gold = q_gold.reshape(num_tokens, -1)
    k_gold = k_gold.reshape(num_tokens, -1)

    # Compare the results.
    torch.testing.assert_close(q.to(torch.float32).cpu(),
                               q_gold,
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)

    torch.testing.assert_close(k.to(torch.float32).cpu(),
                               k_gold,
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)

    torch.testing.assert_close(v.to(torch.float32).cpu(),
                               v_gold.to(torch.float32),
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)

    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_q_heads, num_kv_heads", NUM_QKV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("eps", EPS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_split_qkv_rmsnorm_rope_with_bias(num_tokens, num_q_heads,
                                          num_kv_heads, head_size, eps, dtype,
                                          seed, device):
    torch.manual_seed(seed)
    torch.set_default_device(device)
    init_device_properties_triton()

    q_hidden_size = num_q_heads * head_size
    kv_hidden_size = num_kv_heads * head_size
    qkv = torch.randn(num_tokens,
                      q_hidden_size + kv_hidden_size * 2,
                      dtype=dtype,
                      device=device)
    q_weight = torch.randn(head_size, dtype=dtype, device=device)
    k_weight = torch.randn(head_size, dtype=dtype, device=device)
    q_bias = torch.randn(head_size, dtype=dtype, device=device)
    k_bias = torch.randn(head_size, dtype=dtype, device=device)
    sin = torch.from_numpy(
        np.random.uniform(0, 1,
                          [num_tokens, 1, 1, head_size])).to(dtype).npu()
    cos = torch.from_numpy(
        np.random.uniform(0, 1,
                          [num_tokens, 1, 1, head_size])).to(dtype).npu()
    # fused kernel
    q, k, v = torch.ops.vllm.qkv_rmsnorm_rope(input=qkv,
                                              q_weight=q_weight,
                                              k_weight=k_weight,
                                              q_hidden_size=q_hidden_size,
                                              kv_hidden_size=kv_hidden_size,
                                              head_dim=head_size,
                                              eps=eps,
                                              q_bias=q_bias,
                                              k_bias=k_bias,
                                              cos=cos,
                                              sin=sin)

    # split
    _q, _k, v_gold = qkv.cpu().split(
        [q_hidden_size, kv_hidden_size, kv_hidden_size], dim=-1)
    # norm
    _q = rms_norm(_q.reshape(-1, head_size),
                  q_weight.cpu(),
                  eps,
                  norm_bias=q_bias.cpu())
    _k = rms_norm(_k.reshape(-1, head_size),
                  k_weight.cpu(),
                  eps,
                  norm_bias=k_bias.cpu())
    _q = _q.reshape(num_tokens, 1, -1, head_size)
    _k = _k.reshape(num_tokens, 1, -1, head_size)

    # rope
    q_gold, k_gold = custom_rope(_q, _k, sin.cpu(), cos.cpu())
    q_gold = q_gold.reshape(num_tokens, -1)
    k_gold = k_gold.reshape(num_tokens, -1)

    # Compare the results.
    torch.testing.assert_close(q.to(torch.float32).cpu(),
                               q_gold,
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)

    torch.testing.assert_close(k.to(torch.float32).cpu(),
                               k_gold,
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)

    torch.testing.assert_close(v.to(torch.float32).cpu(),
                               v_gold.to(torch.float32),
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)

    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
@gemini-code-assist (Bot, Contributor) left an inline comment

Severity: high

The two test functions test_split_qkv_rmsnorm_rope and test_split_qkv_rmsnorm_rope_with_bias are largely identical, leading to significant code duplication. This makes the tests harder to read and maintain, as any future changes would need to be applied in two places.

To improve maintainability, these can be consolidated into a single test function parameterized by a with_bias boolean flag. This will remove over 80 lines of duplicated code and make the test logic clearer.

@pytest.mark.parametrize("with_bias", [False, True])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_q_heads, num_kv_heads", NUM_QKV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("eps", EPS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_split_qkv_rmsnorm_rope(with_bias, num_tokens, num_q_heads,
                                num_kv_heads, head_size, eps, dtype, seed,
                                device):
    torch.manual_seed(seed)
    torch.set_default_device(device)
    init_device_properties_triton()

    q_hidden_size = num_q_heads * head_size
    kv_hidden_size = num_kv_heads * head_size
    qkv = torch.randn(num_tokens,
                      q_hidden_size + kv_hidden_size * 2,
                      dtype=dtype,
                      device=device)
    q_weight = torch.randn(head_size, dtype=dtype, device=device)
    k_weight = torch.randn(head_size, dtype=dtype, device=device)
    sin = torch.from_numpy(
        np.random.uniform(0, 1,
                          [num_tokens, 1, 1, head_size])).to(dtype).npu()
    cos = torch.from_numpy(
        np.random.uniform(0, 1,
                          [num_tokens, 1, 1, head_size])).to(dtype).npu()

    q_bias, k_bias = None, None
    norm_q_bias, norm_k_bias = None, None
    if with_bias:
        q_bias = torch.randn(head_size, dtype=dtype, device=device)
        k_bias = torch.randn(head_size, dtype=dtype, device=device)
        norm_q_bias = q_bias.cpu()
        norm_k_bias = k_bias.cpu()

    # fused kernel
    q, k, v = torch.ops.vllm.qkv_rmsnorm_rope(input=qkv,
                                              q_weight=q_weight,
                                              k_weight=k_weight,
                                              q_hidden_size=q_hidden_size,
                                              kv_hidden_size=kv_hidden_size,
                                              head_dim=head_size,
                                              eps=eps,
                                              q_bias=q_bias,
                                              k_bias=k_bias,
                                              cos=cos,
                                              sin=sin)

    # split
    _q, _k, v_gold = qkv.cpu().split(
        [q_hidden_size, kv_hidden_size, kv_hidden_size], dim=-1)
    # norm
    _q = rms_norm(_q.reshape(-1, head_size),
                  q_weight.cpu(),
                  eps,
                  norm_bias=norm_q_bias)
    _k = rms_norm(_k.reshape(-1, head_size),
                  k_weight.cpu(),
                  eps,
                  norm_bias=norm_k_bias)
    _q = _q.reshape(num_tokens, 1, -1, head_size)
    _k = _k.reshape(num_tokens, 1, -1, head_size)

    # rope
    q_gold, k_gold = custom_rope(_q, _k, sin.cpu(), cos.cpu())
    q_gold = q_gold.reshape(num_tokens, -1)
    k_gold = k_gold.reshape(num_tokens, -1)

    # Compare the results.
    torch.testing.assert_close(q.to(torch.float32).cpu(),
                               q_gold,
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)

    torch.testing.assert_close(k.to(torch.float32).cpu(),
                               k_gold,
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)

    torch.testing.assert_close(v.to(torch.float32).cpu(),
                               v_gold.to(torch.float32),
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)

    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
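The reference helpers `rms_norm` and `custom_rope` that the golden path calls are not shown in this diff. As a rough illustration of what such CPU references typically compute, here is a minimal NumPy sketch; the signatures, the optional `norm_bias` argument, and the rotate-half RoPE convention are assumptions for illustration, not taken from the PR:

```python
import numpy as np


def rms_norm(x, weight, eps, norm_bias=None):
    # RMSNorm over the last dim: x / sqrt(mean(x^2) + eps) * weight (+ bias).
    x32 = x.astype(np.float32)
    var = np.mean(x32 ** 2, axis=-1, keepdims=True)
    out = x32 / np.sqrt(var + eps) * weight.astype(np.float32)
    if norm_bias is not None:
        out = out + norm_bias.astype(np.float32)
    return out


def rotate_half(x):
    # Split the last dim in half and rotate: (x1, x2) -> (-x2, x1).
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return np.concatenate([-x2, x1], axis=-1)


def custom_rope(q, k, sin, cos):
    # Classic rotary embedding applied per head: x * cos + rotate_half(x) * sin.
    q_out = q * cos + rotate_half(q) * sin
    k_out = k * cos + rotate_half(k) * sin
    return q_out, k_out
```

With `cos` all ones and `sin` all zeros, `custom_rope` reduces to the identity, which is a handy sanity check when debugging tolerance failures in tests like the ones above.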

@github-actions (Bot)

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling out the PR description to help reviewers and future developers understand the change.

If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.

@github-actions (Bot)

This pull request has conflicts; please resolve them before we can evaluate it.

@wangxiyuan (Collaborator)

Any progress? If this PR is still alive, please rebase onto main and make CI happy; otherwise you can close it. Thanks!

@Angazenn (Collaborator, Author) commented Jan 5, 2026

Merged by #5267.

@Angazenn Angazenn closed this Jan 5, 2026
@Angazenn Angazenn deleted the triton_e2e branch February 4, 2026 06:30