
[XPU] fp8_mqa_logits and fp8_paged_mqa_logits torch fallbacks for XPU#39156

Closed
xwu-intel wants to merge 11 commits into vllm-project:main from xwu-intel:xpu-fp8-logits-torch-fallback

Conversation


@xwu-intel xwu-intel commented Apr 7, 2026

Purpose

The fp8_mqa_logits and fp8_paged_mqa_logits torch fallbacks were removed in #37968. The XPU path still requires them until the ops are implemented in vllm-xpu-kernels.

Test Plan

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

xwu-intel added 11 commits April 7, 2026 10:55
Signed-off-by: Xiaochang Wu <xiaochang.wu@intel.com>
Signed-off-by: Wu, Xiaochang <xiaochang.wu@intel.com>
Add xpu_ops.fp8_mqa_logits_torch and xpu_ops.fp8_paged_mqa_logits_torch, and route XPU sparse_attn_indexer prefill/decode logits through these fallbacks.

Signed-off-by: Wu, Xiaochang <xiaochang.wu@intel.com>
Signed-off-by: Wu, Xiaochang <xiaochang.wu@intel.com>
Add xpu_ops.fp8_mqa_logits_torch and xpu_ops.fp8_paged_mqa_logits_torch, and use them only for XPU logits paths in sparse_attn_indexer.

Signed-off-by: Wu, Xiaochang <xiaochang.wu@intel.com>
Signed-off-by: Wu, Xiaochang <xiaochang.wu@intel.com>
@mergify mergify bot added the intel-gpu Related to Intel GPU label Apr 7, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces XPU-specific PyTorch implementations for FP8 Multi-Query Attention (MQA) and paged MQA logits calculation, integrating them into the sparse attention indexer. The review identifies critical memory efficiency issues in these new methods: the fp8_mqa_logits_torch implementation uses torch.einsum in a way that risks out-of-memory errors for large sequences, and fp8_paged_mqa_logits_torch inefficiently dequantizes the entire KV cache block pool. Suggestions were provided to iterate over heads and dequantize blocks lazily to mitigate these risks.

Comment thread vllm/_xpu_ops.py
Comment on lines +455 to +456
score = torch.einsum("mhd,nd->hmn", q, k).float() * scale
logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
Contributor


critical

The use of torch.einsum to compute a [H, M, N] tensor poses a critical risk of out-of-memory (OOM) errors for large sequence lengths. For instance, with 128 heads and a 1GB logits budget, this intermediate tensor would require 128GB of memory. It is highly recommended to iterate over the heads and accumulate the results to keep memory usage bounded by [M, N].

Suggested change
score = torch.einsum("mhd,nd->hmn", q, k).float() * scale
logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
logits = torch.zeros((q.shape[0], seq_len_kv),
                     device=q.device,
                     dtype=torch.float32)
for h in range(q.shape[1]):
    score_h = (q[:, h, :] @ k.T).float() * scale
    logits += score_h.relu() * weights[:, h, None]
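For intuition, here is a tiny sketch (NumPy standing in for torch, with made-up shapes) checking that the suggested per-head accumulation produces the same [M, N] logits as the flagged einsum formulation, without ever materializing the [H, M, N] intermediate:

```python
import numpy as np

rng = np.random.default_rng(0)
M, H, N, D = 4, 3, 5, 8          # tiny illustrative shapes
scale = 0.1
q = rng.standard_normal((M, H, D)).astype(np.float32)
k = rng.standard_normal((N, D)).astype(np.float32)
weights = rng.standard_normal((M, H)).astype(np.float32)

# Flagged version: materializes an [H, M, N] intermediate in one shot.
score = np.einsum("mhd,nd->hmn", q, k) * scale
full = (np.maximum(score, 0.0) * weights.T[:, :, None]).sum(axis=0)

# Suggested version: accumulate head by head; peak extra memory is one
# [M, N] buffer instead of H of them.
acc = np.zeros((M, N), dtype=np.float32)
for h in range(H):
    score_h = (q[:, h, :] @ k.T) * scale
    acc += np.maximum(score_h, 0.0) * weights[:, h, None]

assert np.allclose(full, acc, atol=1e-5)
```

The loop trades H small matmuls for one batched einsum, which is usually an acceptable cost when the alternative is an H-times-larger intermediate.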

Comment thread vllm/_xpu_ops.py
kv_cache, scale = kv_cache[..., :dim], kv_cache[..., dim:]
scale = scale.contiguous().view(torch.float)
q = q.float()
kv_cache = kv_cache.view(fp8_dtype).float() * scale
Contributor


high

Dequantizing the entire kv_cache block pool into float32 is extremely memory-intensive and inefficient, especially for large context windows where the block pool can be very large. This allocation happens on every call and can lead to OOM or significant performance degradation. It is better to dequantize only the specific blocks required for the current batch inside the loop.
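As a rough illustration of the suggested fix, the sketch below (NumPy standing in for torch; the uint8-codes-plus-per-entry-scale layout is hypothetical, not the real vLLM cache format) dequantizes only the blocks named in a request's block table instead of the whole pool:

```python
import numpy as np

rng = np.random.default_rng(1)
num_blocks, block_size, dim = 16, 4, 8
# Hypothetical fp8-style cache: uint8 codes plus a per-entry fp32 scale.
codes = rng.integers(0, 256, size=(num_blocks, block_size, dim),
                     dtype=np.uint8)
scales = rng.random((num_blocks, block_size, 1)).astype(np.float32)

def dequant_block(idx: int) -> np.ndarray:
    """Dequantize a single cache block on demand."""
    return codes[idx].astype(np.float32) * scales[idx]

block_table = np.array([3, 7, 2])  # blocks this request actually touches

# Eager approach the review flags: dequantize the whole pool every call.
full = codes.astype(np.float32) * scales

# Lazy approach: only the blocks in the table are ever dequantized,
# bounding the float32 allocation by the batch's context, not the pool.
lazy = np.stack([dequant_block(i) for i in block_table])

assert np.array_equal(lazy, full[block_table])
```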

@xwu-intel xwu-intel closed this Apr 7, 2026
@xwu-intel xwu-intel deleted the xpu-fp8-logits-torch-fallback branch April 7, 2026 07:49