
[XPU] Enable topk_per_row and indexer_quant_cache kernels for DeepSeekV3.2 and GLM5#37888

Open
xwu-intel wants to merge 8 commits into vllm-project:main from xwu-intel:xpu-ops-optimization

Conversation

@xwu-intel

@xwu-intel xwu-intel commented Mar 23, 2026

Continue the PR #37869

Waiting for fp8_mqa_logits and fp8_paged_mqa_logits xpu kernels... due to #37968

Purpose

This PR optimizes XPU operations in vLLM by integrating high-performance kernels from vllm-xpu-kernels. Specifically, it replaces the PyTorch fallback implementations for:

  • top_k_per_row_prefill
  • top_k_per_row_decode
  • indexer_k_quant_and_cache
  • cp_gather_indexer_k_quant_cache

The old PyTorch fallback paths were removed.
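For readers unfamiliar with these ops, the core semantics of top_k_per_row (selecting, for each row, the indices of its k largest values, which the new kernel computes natively on XPU) can be sketched in pure Python. This is an illustrative reference only, not the vLLM implementation, which operates on device tensors:

```python
def top_k_per_row(rows, k):
    """Pure-Python reference sketch: for each row, return the column
    indices of its k largest values, largest first. The real vLLM XPU
    kernel implements this on device tensors with far better performance."""
    out = []
    for row in rows:
        # Sort column indices by their value, descending, and keep the top k.
        idx = sorted(range(len(row)), key=lambda j: row[j], reverse=True)[:k]
        out.append(idx)
    return out
```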

Test Plan

  • Locally verify the DeepSeek V3.2 and GLM-5 reduced models to confirm kernel availability and correct execution on B60.

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing the test command.
  • The test results, such as pasting a before/after results comparison, or e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request optimizes several XPU operations by replacing the Python-based fallback implementations with calls to high-performance C++ kernels from vllm-xpu-kernels. This is a valuable performance improvement. The changes also simplify the calling code by removing platform-specific branches. My review has identified a potential critical issue with a custom operator namespace and a typo in a function parameter name. Please see the detailed comments.

Comment thread vllm/_xpu_ops.py Outdated
Comment thread vllm/_xpu_ops.py Outdated
@xwu-intel xwu-intel marked this pull request as ready for review March 23, 2026 13:22
@xwu-intel xwu-intel force-pushed the xpu-ops-optimization branch from 81382a2 to 8a98984 Compare March 23, 2026 13:33
@xwu-intel
Author

@jikunshang @wuxun-zhang please review. This is best merged once the next vllm-xpu-kernels release is out.

Comment thread vllm/model_executor/layers/sparse_attn_indexer.py Outdated
Comment thread vllm/_xpu_ops.py Outdated
Comment thread vllm/_xpu_ops.py Outdated
Comment thread vllm/_xpu_ops.py Outdated

Contributor


@chaojun-zhang can vLLM IR help avoid these duplicated op-dispatch code paths?

Contributor


@ProExpertProg any suggestions here? It seems this custom op doesn't have a native implementation; how should it be handled with vLLM IR?

Contributor


For this case it's simple: there would be a single forward method that calls the registered IR kernel directly, regardless of platform. XPU would register its own IR kernel, which should be the same one as CUDA. Expect no more duplicated code.

Comment thread vllm/model_executor/layers/sparse_attn_indexer.py
@jikunshang jikunshang changed the title Optimize XPU ops using latest vllm-xpu-kernels [XPU] Optimize XPU ops using latest vllm-xpu-kernels Mar 24, 2026
@wuxun-zhang
Contributor

wuxun-zhang commented Mar 25, 2026

@xwu-intel This PR is going to enable the XPU indexer-related kernels for DeepSeek V3.2. I would suggest changing the title to reflect this, something like "enable topk_per_row and indexer_quant_cache kernels for DeepSeekV3.2".

@xwu-intel xwu-intel changed the title [XPU] Optimize XPU ops using latest vllm-xpu-kernels [XPU] Enable topk_per_row and indexer_quant_cache kernels for DeepSeekV3.2 and GLM5 Mar 26, 2026
@mergify mergify bot added deepseek Related to DeepSeek models intel-gpu Related to Intel GPU labels Mar 26, 2026
@mergify
Contributor

mergify bot commented Apr 1, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @xwu-intel.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Apr 1, 2026
@jikunshang
Collaborator

I think we can restart this as we bump vllm-xpu-kernels to the 0.1.5 release.

Comment thread vllm/model_executor/layers/sparse_attn_indexer.py Outdated
@wuxun-zhang
Contributor

@xwu-intel Please rebase the PR.

Signed-off-by: Xiaochang Wu <xiaochang.wu@intel.com>
Signed-off-by: Wu, Xiaochang <xiaochang.wu@intel.com>
Signed-off-by: Wu, Xiaochang <xiaochang.wu@intel.com>
Signed-off-by: Wu, Xiaochang <xiaochang.wu@intel.com>
Signed-off-by: Wu, Xiaochang <xiaochang.wu@intel.com>
Signed-off-by: Wu, Xiaochang <xiaochang.wu@intel.com>
Signed-off-by: Wu, Xiaochang <xiaochang.wu@intel.com>
Signed-off-by: Wu, Xiaochang <xiaochang.wu@intel.com>
@xwu-intel xwu-intel force-pushed the xpu-ops-optimization branch from f127edb to 08ec43a Compare April 7, 2026 03:00
@mergify mergify bot removed the needs-rebase label Apr 7, 2026
@wuxun-zhang
Contributor

cc @xinyu-intel

Comment thread vllm/_xpu_ops.py Outdated
Signed-off-by: Wu, Xiaochang <xiaochang.wu@intel.com>
@xwu-intel
Author

Waiting for fp8_mqa_logits and fp8_paged_mqa_logits xpu kernels... due to #37968

@mergify
Contributor

mergify bot commented Apr 9, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @xwu-intel.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork


Labels

deepseek Related to DeepSeek models intel-gpu Related to Intel GPU needs-rebase


5 participants