-
Notifications
You must be signed in to change notification settings - Fork 1.1k
[model_runner_v2]optimize the performance of the post_update. #7496
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
weijinqian0
merged 6 commits into
vllm-project:main
from
weijinqian0:triton_post_update
Mar 23, 2026
Merged
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
5f453a4
[model_runner_v2]optimize the performance of the _topk_log_softmax_ke…
7a5eb8a
[model_runner_v2]optimize the performance of the post_update.
8e85b15
Merge branch 'main' into triton_post_update
weijinqian0 7388b09
Merge branch 'main' into triton_post_update
weijinqian0 a454c55
Merge branch 'main' into triton_post_update
weijinqian0 03dd028
[model_runner_v2]optimize the performance of the post_update.
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
97 changes: 97 additions & 0 deletions
97
tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_post_update.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,97 @@ | ||
| from typing import Dict, Any | ||
|
|
||
| import torch | ||
| import pytest | ||
| from vllm.v1.worker.gpu.input_batch import post_update as post_update_gpu | ||
| from vllm_ascend.worker.v2.input_batch import post_update as post_update_npu | ||
|
|
||
|
|
||
| def generate_test_data(num_reqs: int, max_num_reqs: int, vocab_size: int, num_speculative_steps: int, device: str) -> \ | ||
| Dict[str, Any]: | ||
| """ | ||
| Generate random test data. | ||
| Return a dictionary containing all input tensors and the additional field 'expected_query_lens' for validation. | ||
| """ | ||
| num_cols = num_speculative_steps + 1 | ||
|
|
||
| if num_reqs > max_num_reqs: | ||
| raise ValueError("num_reqs cannot be larger than max_num_reqs") | ||
|
|
||
| idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device) | ||
| num_computed_tokens = torch.randint(0, 100, (max_num_reqs,), dtype=torch.int32, device=device) | ||
| last_sampled_tokens = torch.randint(0, vocab_size, (max_num_reqs,), dtype=torch.int32, device=device) | ||
| output_bin_counts = torch.randint(0, 10, (max_num_reqs, vocab_size), dtype=torch.int32, device=device) | ||
| sampled_tokens = torch.randint(0, vocab_size, (num_reqs, num_speculative_steps + 1), dtype=torch.int32, | ||
| device=device) | ||
| num_sampled = torch.randint(1, num_speculative_steps + 2, (num_reqs,), dtype=torch.int32, device=device) | ||
| num_rejected = torch.randint(0, num_speculative_steps + 1, (num_reqs,), dtype=torch.int32, device=device) | ||
| num_rejected = torch.min(num_rejected, num_sampled - 1) | ||
|
|
||
| query_lengths = torch.randint(1, 20, (num_reqs,), dtype=torch.int32, device=device) | ||
| query_start_loc = torch.cat([ | ||
| torch.tensor([0], dtype=torch.int32, device=device), | ||
| torch.cumsum(query_lengths, dim=0) | ||
| ]) | ||
| total_len = torch.randint(50, 200, (max_num_reqs,), dtype=torch.int32, device=device) | ||
|
|
||
| max_model_len = 3000 # 或者可以从total_len的最大值获取 | ||
| all_token_ids = torch.randint(0, vocab_size, (max_num_reqs, max_model_len), dtype=torch.int32, device=device) | ||
|
|
||
| return { | ||
| "idx_mapping": idx_mapping, | ||
| "num_computed_tokens": num_computed_tokens, | ||
| "last_sampled_tokens": last_sampled_tokens, | ||
| "output_bin_counts": output_bin_counts, | ||
| "sampled_tokens": sampled_tokens, | ||
| "num_sampled": num_sampled, | ||
| "num_rejected": num_rejected, | ||
| "query_start_loc": query_start_loc, | ||
| "all_token_ids": all_token_ids, | ||
| "total_len": total_len | ||
| } | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("num_reqs,max_num_reqs,vocab_size,num_speculative_steps", [ | ||
| (36, 36, 200, 2), | ||
| (48, 48, 32000, 5), | ||
| (128, 128, 32000, 5), | ||
| ]) | ||
| def test_post_update(num_reqs: int, max_num_reqs: int, vocab_size: int, num_speculative_steps: int): | ||
| """Test _topk_log_softmax_kernel for computing log probabilities | ||
| Args: | ||
| batch_size: Number of sequences in the batch | ||
| vocab_size: Size of the vocabulary | ||
| num_logprobs: Number of tokens to compute log probabilities for | ||
| """ | ||
| torch.manual_seed(42) | ||
|
|
||
| post_update_params = ["idx_mapping", | ||
| "num_computed_tokens", | ||
| "last_sampled_tokens", | ||
| "output_bin_counts", | ||
| "sampled_tokens", | ||
| "num_sampled", | ||
| "num_rejected", | ||
| "query_start_loc", | ||
| "all_token_ids", | ||
| "total_len" | ||
| ] | ||
|
|
||
| data = generate_test_data(num_reqs, max_num_reqs, vocab_size, num_speculative_steps, device="npu") | ||
| kernel_inputs_gpu = {k: data[k].clone() for k in post_update_params} | ||
| kernel_inputs_npu = {k: data[k].clone() for k in post_update_params} | ||
|
|
||
| # Invoke Triton kernel | ||
| post_update_gpu(**kernel_inputs_gpu) | ||
| torch.npu.synchronize() | ||
|
|
||
| post_update_npu(**kernel_inputs_npu) | ||
| torch.npu.synchronize() | ||
|
|
||
| # ========== Verify results ========== | ||
| assert torch.allclose(kernel_inputs_gpu["output_bin_counts"], kernel_inputs_npu["output_bin_counts"], rtol=1e-3, | ||
| atol=1e-3), \ | ||
| f"Triton output differs from PyTorch reference.\n" \ | ||
| f"Max diff: {torch.max(torch.abs(kernel_inputs_npu['output_bin_counts'] - kernel_inputs_npu['output_bin_counts']))}\n" \ | ||
| f"Mean diff: {torch.mean(torch.abs(kernel_inputs_npu['output_bin_counts'] - kernel_inputs_npu['output_bin_counts']))}" | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.