[XPU] fp8_mqa_logits and fp8_paged_mqa_logits torch fallbacks for XPU#39156
xwu-intel wants to merge 11 commits into vllm-project:main
Conversation
Signed-off-by: Xiaochang Wu <xiaochang.wu@intel.com>
Signed-off-by: Wu, Xiaochang <xiaochang.wu@intel.com>
Add xpu_ops.fp8_mqa_logits_torch and xpu_ops.fp8_paged_mqa_logits_torch, and route XPU sparse_attn_indexer prefill/decode logits through these fallbacks.
Add xpu_ops.fp8_mqa_logits_torch and xpu_ops.fp8_paged_mqa_logits_torch, and use them only for XPU logits paths in sparse_attn_indexer.
Code Review
This pull request introduces XPU-specific PyTorch implementations for FP8 Multi-Query Attention (MQA) and paged MQA logits calculation, integrating them into the sparse attention indexer. The review identifies critical memory efficiency issues in these new methods: the fp8_mqa_logits_torch implementation uses torch.einsum in a way that risks out-of-memory errors for large sequences, and fp8_paged_mqa_logits_torch inefficiently dequantizes the entire KV cache block pool. Suggestions were provided to iterate over heads and dequantize blocks lazily to mitigate these risks.
```python
score = torch.einsum("mhd,nd->hmn", q, k).float() * scale
logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
```
The use of torch.einsum to compute a [H, M, N] tensor poses a critical risk of out-of-memory (OOM) errors for large sequence lengths. For instance, with 128 heads and a 1GB logits budget, this intermediate tensor would require 128GB of memory. It is highly recommended to iterate over the heads and accumulate the results to keep memory usage bounded by [M, N].
Suggested change — replace the einsum with a per-head accumulation:

```python
logits = torch.zeros((q.shape[0], seq_len_kv),
                     device=q.device,
                     dtype=torch.float32)
for h in range(q.shape[1]):
    score_h = (q[:, h, :] @ k.T).float() * scale
    logits += score_h.relu() * weights[:, h, None]
```
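The two formulations are numerically equivalent; a quick standalone check with hypothetical small shapes (float inputs standing in for the FP8 path, since only the math is being verified):

```python
import torch

# Hypothetical shapes standing in for the real call sites.
M, N, H, D = 4, 6, 8, 16
q = torch.randn(M, H, D)
k = torch.randn(N, D)
weights = torch.randn(M, H)
scale = 0.125

# Reference: materializes an [H, M, N] intermediate.
score = torch.einsum("mhd,nd->hmn", q, k).float() * scale
ref = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)

# Memory-bounded variant: accumulate per head, keeping only [M, N] live.
logits = torch.zeros((M, N), dtype=torch.float32)
for h in range(H):
    score_h = (q[:, h, :] @ k.T).float() * scale
    logits += score_h.relu() * weights[:, h, None]

assert torch.allclose(ref, logits, atol=1e-4)
```

The loop trades one large [H, M, N] allocation for H matmuls over an [M, N] accumulator, which is what keeps peak memory bounded regardless of head count.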
```python
kv_cache, scale = kv_cache[..., :dim], kv_cache[..., dim:]
scale = scale.contiguous().view(torch.float)
q = q.float()
kv_cache = kv_cache.view(fp8_dtype).float() * scale
```
Dequantizing the entire kv_cache block pool into float32 is extremely memory-intensive and inefficient, especially for large context windows where the block pool can be very large. This allocation happens on every call and can lead to OOM or significant performance degradation. It is better to dequantize only the specific blocks required for the current batch inside the loop.
Purpose
fp8_mqa_logits and fp8_paged_mqa_logits fallbacks were removed in #37968. XPU path still requires them until the ops are implemented in vllm-xpu-kernels.
Test Plan
Test Result
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.