[TEST]test e2e for split_qkv_rmsnorm_rope#5320
[TEST] test e2e for split_qkv_rmsnorm_rope #5320 — Angazenn wants to merge 1 commit into vllm-project:main from
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces end-to-end tests for the split_qkv_rmsnorm_rope custom operator, covering cases with and without bias. The change to make bias arguments optional in the operator implementation is correct. However, the new test file contains two nearly identical test functions, one for the case with bias and one without. This significant code duplication can be avoided by merging them into a single, parameterized test function, which would improve maintainability. I've provided a suggestion to refactor the tests accordingly.
@pytest.mark.parametrize("with_bias", [False, True])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_q_heads, num_kv_heads", NUM_QKV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("eps", EPS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_split_qkv_rmsnorm_rope(with_bias, num_tokens, num_q_heads,
                                num_kv_heads, head_size, eps, dtype, seed,
                                device):
    """E2E test of the fused qkv_rmsnorm_rope op against a reference path.

    The reference path is: split packed qkv -> per-head rms_norm
    (optionally with bias) -> rope.  ``with_bias`` toggles the optional
    q/k norm biases so one parameterized test covers both operator
    signatures instead of two duplicated test functions.
    """
    torch.manual_seed(seed)
    torch.set_default_device(device)
    init_device_properties_triton()

    q_hidden_size = num_q_heads * head_size
    kv_hidden_size = num_kv_heads * head_size
    # Packed [q | k | v] activations, as produced by a fused qkv projection.
    qkv = torch.randn(num_tokens,
                      q_hidden_size + kv_hidden_size * 2,
                      dtype=dtype,
                      device=device)
    q_weight = torch.randn(head_size, dtype=dtype, device=device)
    k_weight = torch.randn(head_size, dtype=dtype, device=device)
    # Rope tables; uniform values in [0, 1) suffice for a numeric comparison.
    sin = torch.from_numpy(
        np.random.uniform(0, 1,
                          [num_tokens, 1, 1, head_size])).to(dtype).npu()
    cos = torch.from_numpy(
        np.random.uniform(0, 1,
                          [num_tokens, 1, 1, head_size])).to(dtype).npu()
    q_bias, k_bias = None, None
    norm_q_bias, norm_k_bias = None, None
    if with_bias:
        q_bias = torch.randn(head_size, dtype=dtype, device=device)
        k_bias = torch.randn(head_size, dtype=dtype, device=device)
        norm_q_bias = q_bias.cpu()
        norm_k_bias = k_bias.cpu()

    # Fused kernel under test (bias arguments are optional and may be None).
    q, k, v = torch.ops.vllm.qkv_rmsnorm_rope(input=qkv,
                                              q_weight=q_weight,
                                              k_weight=k_weight,
                                              q_hidden_size=q_hidden_size,
                                              kv_hidden_size=kv_hidden_size,
                                              head_dim=head_size,
                                              eps=eps,
                                              q_bias=q_bias,
                                              k_bias=k_bias,
                                              cos=cos,
                                              sin=sin)

    # Reference: split the packed projection on CPU.
    _q, _k, v_gold = qkv.cpu().split(
        [q_hidden_size, kv_hidden_size, kv_hidden_size], dim=-1)
    # Reference: per-head rms_norm (bias only when with_bias is True).
    _q = rms_norm(_q.reshape(-1, head_size),
                  q_weight.cpu(),
                  eps,
                  norm_bias=norm_q_bias)
    _k = rms_norm(_k.reshape(-1, head_size),
                  k_weight.cpu(),
                  eps,
                  norm_bias=norm_k_bias)
    _q = _q.reshape(num_tokens, 1, -1, head_size)
    _k = _k.reshape(num_tokens, 1, -1, head_size)

    # Reference: rope.
    q_gold, k_gold = custom_rope(_q, _k, sin.cpu(), cos.cpu())
    q_gold = q_gold.reshape(num_tokens, -1)
    k_gold = k_gold.reshape(num_tokens, -1)

    # Compare fused vs. reference results in float32 on CPU.
    torch.testing.assert_close(q.to(torch.float32).cpu(),
                               q_gold,
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)
    torch.testing.assert_close(k.to(torch.float32).cpu(),
                               k_gold,
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)
    torch.testing.assert_close(v.to(torch.float32).cpu(),
                               v_gold.to(torch.float32),
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)

    # Release NPU memory between heavily parameterized runs.
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
There was a problem hiding this comment.
The two test functions test_split_qkv_rmsnorm_rope and test_split_qkv_rmsnorm_rope_with_bias are largely identical, leading to significant code duplication. This makes the tests harder to read and maintain, as any future changes would need to be applied in two places.
To improve maintainability, these can be consolidated into a single test function parameterized by a with_bias boolean flag. This will remove over 80 lines of duplicated code and make the test logic clearer.
@pytest.mark.parametrize("with_bias", [False, True])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_q_heads, num_kv_heads", NUM_QKV_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("eps", EPS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_split_qkv_rmsnorm_rope(with_bias, num_tokens, num_q_heads,
                                num_kv_heads, head_size, eps, dtype, seed,
                                device):
    """E2E test of the fused qkv_rmsnorm_rope op against a reference path.

    The reference path is: split packed qkv -> per-head rms_norm
    (optionally with bias) -> rope.  ``with_bias`` toggles the optional
    q/k norm biases so one parameterized test covers both operator
    signatures.
    """
    torch.manual_seed(seed)
    torch.set_default_device(device)
    init_device_properties_triton()

    q_hidden_size = num_q_heads * head_size
    kv_hidden_size = num_kv_heads * head_size
    # Packed [q | k | v] activations, as produced by a fused qkv projection.
    qkv = torch.randn(num_tokens,
                      q_hidden_size + kv_hidden_size * 2,
                      dtype=dtype,
                      device=device)
    q_weight = torch.randn(head_size, dtype=dtype, device=device)
    k_weight = torch.randn(head_size, dtype=dtype, device=device)
    # Rope tables; uniform values in [0, 1) suffice for a numeric comparison.
    sin = torch.from_numpy(
        np.random.uniform(0, 1,
                          [num_tokens, 1, 1, head_size])).to(dtype).npu()
    cos = torch.from_numpy(
        np.random.uniform(0, 1,
                          [num_tokens, 1, 1, head_size])).to(dtype).npu()
    q_bias, k_bias = None, None
    norm_q_bias, norm_k_bias = None, None
    if with_bias:
        q_bias = torch.randn(head_size, dtype=dtype, device=device)
        k_bias = torch.randn(head_size, dtype=dtype, device=device)
        norm_q_bias = q_bias.cpu()
        norm_k_bias = k_bias.cpu()
    # Fused kernel under test (bias arguments are optional and may be None).
    q, k, v = torch.ops.vllm.qkv_rmsnorm_rope(input=qkv,
                                              q_weight=q_weight,
                                              k_weight=k_weight,
                                              q_hidden_size=q_hidden_size,
                                              kv_hidden_size=kv_hidden_size,
                                              head_dim=head_size,
                                              eps=eps,
                                              q_bias=q_bias,
                                              k_bias=k_bias,
                                              cos=cos,
                                              sin=sin)
    # Reference: split the packed projection on CPU.
    _q, _k, v_gold = qkv.cpu().split(
        [q_hidden_size, kv_hidden_size, kv_hidden_size], dim=-1)
    # Reference: per-head rms_norm (bias only when with_bias is True).
    _q = rms_norm(_q.reshape(-1, head_size),
                  q_weight.cpu(),
                  eps,
                  norm_bias=norm_q_bias)
    _k = rms_norm(_k.reshape(-1, head_size),
                  k_weight.cpu(),
                  eps,
                  norm_bias=norm_k_bias)
    _q = _q.reshape(num_tokens, 1, -1, head_size)
    _k = _k.reshape(num_tokens, 1, -1, head_size)
    # Reference: rope.
    q_gold, k_gold = custom_rope(_q, _k, sin.cpu(), cos.cpu())
    q_gold = q_gold.reshape(num_tokens, -1)
    k_gold = k_gold.reshape(num_tokens, -1)
    # Compare fused vs. reference results in float32 on CPU.
    torch.testing.assert_close(q.to(torch.float32).cpu(),
                               q_gold,
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)
    torch.testing.assert_close(k.to(torch.float32).cpu(),
                               k_gold,
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)
    torch.testing.assert_close(v.to(torch.float32).cpu(),
                               v_gold.to(torch.float32),
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)
    # Release NPU memory between heavily parameterized runs.
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run the linting and testing checks locally according to the Contributing and Testing guides. |
|
This pull request has conflicts; please resolve them so we can evaluate the pull request. |
|
Any progress? If this PR is still active, please rebase it onto main and make CI pass; otherwise, you can close it. Thanks.
|
merged by #5267 . |
What this PR does / why we need it?
Does this PR introduce any user-facing change?
How was this patch tested?