Skip to content

[model_runner_v2]optimize the performance of the _topk_log_softmax_kernel#7221

Merged
wangxiyuan merged 6 commits intovllm-project:mainfrom
wangx700:logprobs_mr_v2
Mar 16, 2026
Merged

[model_runner_v2]optimize the performance of the _topk_log_softmax_kernel#7221
wangxiyuan merged 6 commits intovllm-project:mainfrom
wangx700:logprobs_mr_v2

Conversation

@wangx700
Copy link
Copy Markdown
Contributor

@wangx700 wangx700 commented Mar 13, 2026

What this PR does / why we need it?

Optimize the performance of the triton operator _topk_log_softmax_kernel in model_runner_v2 to 1.04xH100,which is 7% of its original value.(issue #5208)

Does this PR introduce any user-facing change?

How was this patch tested?

…rnel Triton operator

in model_runner_v2 to 7% of its original value.

Signed-off-by: wangx700 <wangxin700@huawei.com>
Signed-off-by: wangx700 <wangxin700@huawei.com>
@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the performance of log probability calculations within the model_runner_v2 component. It achieves this by introducing a highly optimized Triton kernel for _topk_log_softmax_kernel, resulting in a substantial speedup. The changes also include comprehensive testing to ensure the accuracy and reliability of the new kernel.

Highlights

  • Performance Optimization: The _topk_log_softmax_kernel in model_runner_v2 has been significantly optimized, achieving a performance improvement to 7% of its original execution time.
  • New Triton Kernel Implementation: A new Triton-based kernel, _topk_log_softmax_kernel, was introduced to efficiently compute log probabilities for top-k tokens.
  • Test Coverage: A new end-to-end test was added to validate the correctness and accuracy of the _topk_log_softmax_kernel against a PyTorch reference implementation.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Changelog
  • tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_log_softmax.py
    • Added a new test file for _topk_log_softmax_kernel.
  • vllm_ascend/worker/v2/sample/logprob.py
    • Added a new file containing the _topk_log_softmax_kernel Triton implementation.
    • Added the compute_token_logprobs function utilizing the new kernel.
Activity
  • No activity has occurred on this pull request since its creation.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@wangx700 wangx700 changed the title [model_runner_v2]Optimize the performance of the _topk_log_softmax_kernel [model_runner_v2]optimize the performance of the _topk_log_softmax_kernel Mar 13, 2026
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces an optimized Triton kernel _topk_log_softmax_kernel for calculating log probabilities, claiming a significant performance improvement. It also adds a new test to verify the kernel's correctness. My review includes suggestions to further improve the kernel's efficiency and to enhance the new test's implementation for better performance and readability. Additionally, I've provided an updated PR title and summary to align with the repository's contribution guidelines.

Suggested PR Title:

[Ops][Perf] Optimize the performance of _topk_log_softmax_kernel

Suggested PR Summary:

### What this PR does / why we need it?

This PR optimizes the `_topk_log_softmax_kernel` Triton kernel, which is used for calculating log probabilities for specified tokens. The key improvements are within the kernel's implementation of the log-softmax operation, which are intended to reduce execution time significantly.

A new test is also added to verify the correctness of the optimized kernel by comparing its output against the standard PyTorch `log_softmax` implementation.

### Does this PR introduce _any_ user-facing change?

No, this is a backend performance optimization and does not introduce any user-facing changes.

### How was this patch tested?

A new test, `test_topk_log_softmax_kernel`, has been added to `tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_log_softmax.py`. This test covers various configurations of batch size, vocabulary size, and number of logprobs. It validates the Triton kernel's output against a reference PyTorch implementation, ensuring correctness.

Comment on lines +46 to +54
for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
logits = tl.load(row_ptr + block, mask=block < vocab_size, other=0.0)
# NOTE(woosuk): Make sure that logits and all following operations use FP32.
logits = logits.to(tl.float32)
block = block.to(tl.float32)
e = tl.exp(logits - max_val)
e = tl.where(block < vocab_size, e, 0.0)
se += tl.sum(e)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The loop for calculating the sum of exponentials (se) can be made more efficient and readable. Currently, it loads values with other=0.0, computes exp, and then uses tl.where to mask out-of-bound elements. This performs unnecessary computations.

A better approach is to load with other=float("-inf"). This way, tl.exp on out-of-bound elements will naturally result in 0.0, removing the need for an explicit tl.where call. Additionally, the type conversion block.to(tl.float32) is unnecessary as block is only used for masking and pointer arithmetic.

Suggested change
for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
logits = tl.load(row_ptr + block, mask=block < vocab_size, other=0.0)
# NOTE(woosuk): Make sure that logits and all following operations use FP32.
logits = logits.to(tl.float32)
block = block.to(tl.float32)
e = tl.exp(logits - max_val)
e = tl.where(block < vocab_size, e, 0.0)
se += tl.sum(e)
for i in range(0, vocab_size, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf"))
# NOTE(woosuk): Make sure that logits and all following operations use FP32.
logits = logits.to(tl.float32)
e = tl.exp(logits - max_val)
se += tl.sum(e)

Comment on lines +55 to +59
ref_output = torch.zeros_like(triton_output)
for i in range(batch_size):
for j in range(num_logprobs):
token_id = token_ids[i, j]
ref_output[i, j] = torch_logprobs[i, token_id]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The reference output calculation uses nested Python loops, which is inefficient and less readable than a vectorized approach. You can achieve the same result more efficiently and concisely by using torch.gather.

Suggested change
ref_output = torch.zeros_like(triton_output)
for i in range(batch_size):
for j in range(num_logprobs):
token_id = token_ids[i, j]
ref_output[i, j] = torch_logprobs[i, token_id]
ref_output = torch.gather(torch_logprobs, 1, token_ids)

@github-actions
Copy link
Copy Markdown
Contributor

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:‌‌

  • A PR should do only one thing, smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests ‌to ensure it works and is not broken by other future PRs.
  • Write the commit message by fulfilling the PR description to help reviewer and future developers understand.

If CI fails, you can run linting and testing checks locally according Contributing and Testing.

Signed-off-by: wangx700 <wangxin700@huawei.com>
Signed-off-by: wangx700 <wangxin700@huawei.com>
@Ronald1995
Copy link
Copy Markdown
Contributor

please refer to #5208.

logits = logits.to(tl.float32)
block = block.to(tl.float32)
e = tl.exp(logits - max_val)
e = tl.where(block < vocab_size, e, 0.0)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add some comments to explain the difference between vllm and vllm-ascend.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok,I added some notes to explain the difference between vllm and vllm-ascend.

Signed-off-by: wangx700 <wangxin700@huawei.com>
Signed-off-by: wangx700 <wangxin700@huawei.com>
@wangxiyuan wangxiyuan merged commit 22d0e1d into vllm-project:main Mar 16, 2026
36 checks passed
@Ronald1995 Ronald1995 mentioned this pull request Mar 16, 2026
35 tasks
Nagisa125 pushed a commit to starmountain1997/vllm-ascend that referenced this pull request Mar 17, 2026
…rnel (vllm-project#7221)

### What this PR does / why we need it?
Optimize the performance of the triton operator _topk_log_softmax_kernel
in model_runner_v2 to 1.04xH100,which is 7% of its original value.(issue
vllm-project#5208)

- vLLM version: v0.16.0
- vLLM main:
vllm-project/vllm@4034c3d

---------

Signed-off-by: wangx700 <wangxin700@huawei.com>
ichaoren pushed a commit to ichaoren/vllm-ascend that referenced this pull request Mar 17, 2026
…rnel (vllm-project#7221)

### What this PR does / why we need it?
Optimize the performance of the triton operator _topk_log_softmax_kernel
in model_runner_v2 to 1.04xH100,which is 7% of its original value.(issue
vllm-project#5208)

- vLLM version: v0.16.0
- vLLM main:
vllm-project/vllm@4034c3d

---------

Signed-off-by: wangx700 <wangxin700@huawei.com>
Signed-off-by: xutianyi <xutianyi5@huawei.com>
845473182 pushed a commit to 845473182/vllm-ascend that referenced this pull request Mar 18, 2026
…scend into qwen3next_graph

* 'qwen3next_graph' of https://github.com/845473182/vllm-ascend: (62 commits)
  [doc] Refresh the documentation for DeepSeek-V3.2 (vllm-project#7403)
  [bugfix][accuracy] Fix ds indexer accuracy problem caused by k rope (vllm-project#7341)
  [P/D] LayerwiseConnector supports the virtual push functionality on node D. (vllm-project#7361)
  [CI] Add PAT_TOKEN when checkout (vllm-project#7400)
  [main2main] upgrade vllm to 0308 (vllm-project#7213)
  [CI] add scheduled stale issue management (vllm-project#7354)
  [CI] expand issue labeler rules for feature/model triage (vllm-project#7356)
  [Bugfix] Assertion error when decode prefix cache fully hits (vllm-project#7236)
  [doc] Refresh the documentation for GLM-4.7 (vllm-project#7292)
  [BugFix]A2 MOE method&& layerwise MTP bugfix && Mamba gdn_metadata bugfix (vllm-project#7364)
  [doc] Upload doc for qwen3.5-27B and qwen3.5-397B-A17B on Ascend (vllm-project#7313)
  [bugfix]Enable dispatch_ffn_combine feature for qwen3.5 (vllm-project#7066)
  [bugfix] fix unzip file path for fia operator (vllm-project#7367)
  [Perf] Optimize bias handling in AscendRMSNorm (vllm-project#7226)
  [eagle3][pcp] fix bug for eagle3 and cp enable (vllm-project#7309)
  [Bugfix] fix TransposeKvCacheByBlock op error report in plog (vllm-project#7235)
  [Feature]Supports DSv3.1 PD separation and C8 quantization (vllm-project#7222)
  [main][bugfix] Fixed the problem that eagle3 will crash in FULL_DECODE_ONLY (vllm-project#7290)
  [xlite][Bugfix] Support mrope and deepstack features in xlite backend (vllm-project#7295)
  [model_runner_v2]optimize the performance of the _topk_log_softmax_kernel (vllm-project#7221)
  ...
winson-00178005 added a commit to winson-00178005/vllm-ascend that referenced this pull request Mar 26, 2026
- Remove is_skipped flag from tests/e2e/singlecard/model_runner_v2/test_basic.py
- Test was originally skipped due to get_cuda_view_from_cpu_tensor error (vllm-project#5752)
- Recent model_runner_v2 improvements may have resolved the issue:
  - vllm-project#7110: Added aclgraph support
  - vllm-project#7496: Optimized post_update performance
  - vllm-project#7221: Optimized _topk_log_softmax_kernel performance
- CI will verify if the test now passes successfully

Signed-off-by: hejianping <hejianping7@huawei.com>
winson-00178005 added a commit to winson-00178005/vllm-ascend that referenced this pull request Mar 26, 2026
- Remove is_skipped flag from tests/e2e/singlecard/model_runner_v2/test_basic.py
- Test was originally skipped due to get_cuda_view_from_cpu_tensor error (vllm-project#5752)
- Recent model_runner_v2 improvements may have resolved the issue:
  - vllm-project#7110: Added aclgraph support
  - vllm-project#7496: Optimized post_update performance
  - vllm-project#7221: Optimized _topk_log_softmax_kernel performance
- CI will verify if test now passes successfully

Signed-off-by: hejianping <hejianping7@huawei.com>
winson-00178005 added a commit to winson-00178005/vllm-ascend that referenced this pull request Mar 26, 2026
- Remove is_skipped flag from tests/e2e/singlecard/model_runner_v2/test_basic.py
- Test was originally skipped due to get_cuda_view_from_cpu_tensor error (vllm-project#5752)
- Recent model_runner_v2 improvements may have resolved the issue:
  - vllm-project#7110: Added aclgraph support
  - vllm-project#7496: Optimized post_update performance
  - vllm-project#7221: Optimized _topk_log_softmax_kernel performance
- CI will verify if the test now passes successfully

Signed-off-by: hejianping <hejianping7@huawei.com>
chenchuw886 pushed a commit to chenchuw886/vllm-ascend that referenced this pull request Apr 1, 2026
…rnel (vllm-project#7221)

### What this PR does / why we need it?
Optimize the performance of the triton operator _topk_log_softmax_kernel
in model_runner_v2 to 1.04xH100,which is 7% of its original value.(issue
vllm-project#5208)

- vLLM version: v0.16.0
- vLLM main:
vllm-project/vllm@4034c3d

---------

Signed-off-by: wangx700 <wangxin700@huawei.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants