feat: add DeepSeek-V4 XPU attention decode path#42953
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces XPU support for DeepSeek V4 attention and sparse indexing, including new Triton kernels for normalization, RoPE, and sparse MLA. It also implements a dequantization strategy for FP8 KV caches on XPU. The reviewer identified several performance bottlenecks in the decode path, such as nested Python loops and frequent memory allocations, recommending vectorization and the use of a workspace manager. Additionally, feedback includes correcting hardware-specific FP8 types for XPU and removing debug print statements.
| for i in range(batch_size): | ||
| seq_len = int(context_lens[i].item()) | ||
| if seq_len <= 0: | ||
| continue | ||
| num_pages = cdiv(seq_len, block_size) | ||
| pages = block_table[i, :num_pages] | ||
| padded_seq_len = num_pages * block_size | ||
|
|
||
| # Gather K and scale from block-packed layout | ||
| # For each page, K region: [page_start, page_start + bs*dim) | ||
| # Scale region: [page_start + bs*dim, page_start + bs*dim + bs*4) | ||
| page_starts = pages.to(torch.int64) * block_stride | ||
|
|
||
| # K indices: for each page p, for each pos j: page_start + j*dim + d | ||
| pos_offsets = torch.arange(block_size, device=device) | ||
| dim_offsets = torch.arange(dim, device=device) | ||
| # [num_pages, block_size, dim] | ||
| k_indices = ( | ||
| page_starts[:, None, None] | ||
| + pos_offsets[None, :, None] * dim | ||
| + dim_offsets[None, None, :] | ||
| ) | ||
| k_uint8 = cache_flat[k_indices.reshape(-1)].reshape( | ||
| padded_seq_len, dim | ||
| ) | ||
| k_fp8 = k_uint8.view(fp8_dtype).to(torch.float32) | ||
|
|
||
| # Scale indices: for each page p, for each pos j: | ||
| # page_start + bs*dim + j*4 + byte | ||
| scale_byte_offsets = torch.arange(4, device=device) | ||
| scale_indices = ( | ||
| page_starts[:, None, None] | ||
| + block_size * dim | ||
| + pos_offsets[None, :, None] * 4 | ||
| + scale_byte_offsets[None, None, :] | ||
| ) | ||
| scale_uint8 = cache_flat[scale_indices.reshape(-1)].reshape( | ||
| padded_seq_len, 4 | ||
| ) | ||
| k_scale = scale_uint8.view(torch.float32) # [padded_seq_len, 1] | ||
|
|
||
| for n in range(next_n): | ||
| q_i = q[i, n].to(torch.float32) # [H, D] | ||
| w_i = weights[i * next_n + n] # [H] | ||
|
|
||
| # Compute per-head scores: [H, padded_seq_len] | ||
| scores = torch.mm(q_i, k_fp8[:padded_seq_len].T) # [H, S] | ||
| scores = scores * k_scale[:padded_seq_len].T # broadcast scale | ||
| scores = torch.relu(scores) | ||
| # Weight and sum over heads | ||
| weighted = (scores * w_i[:, None]).sum(dim=0) # [S] | ||
| logits[i * next_n + n, :seq_len] = weighted[:seq_len] | ||
|
|
||
| return logits |
| workspace = torch.empty( | ||
| (num_tokens * K_total, OUTPUT_DIM), dtype=torch.bfloat16, device=device | ||
| ) | ||
| ws_3d = workspace.view(num_tokens, K_total, OUTPUT_DIM) | ||
|
|
||
| # Dequant+gather topk slots from compressed cache | ||
| if not swa_only and topk_idx_2d is not None and kv_cache is not None: | ||
| topk_flat = topk_idx_2d.reshape(-1).to(torch.int32) | ||
| topk_buf = torch.empty( | ||
| (num_tokens * max_topk, OUTPUT_DIM), dtype=torch.bfloat16, device=device | ||
| ) | ||
| compressed_block_size = kv_cache.shape[1] | ||
| dequant_gather_slots(topk_buf, kv_cache, topk_flat, compressed_block_size) | ||
| ws_3d[:, :max_topk, :] = topk_buf.view(num_tokens, max_topk, OUTPUT_DIM) | ||
|
|
||
| # Dequant+gather SWA slots | ||
| swa_flat = swa_idx_2d.reshape(-1).to(torch.int32) | ||
| swa_buf = torch.empty( |
There was a problem hiding this comment.
| for t_idx in range(num_tokens): | ||
| tlen = int(topk_lens[t_idx].item()) | ||
| slen = int(swa_lens[t_idx].item()) | ||
| combined_indices[t_idx, tlen:tlen + slen] = ( | ||
| swa_ws_indices[t_idx, :slen] | ||
| ) |
| num_heads = q.shape[1] | ||
|
|
||
| # Allocate temp buffer for RoPE-applied KV | ||
| kv_roped = torch.empty_like(kv) |
|
@jikunshang @xinyu-intel @xuechendi @wuxun-zhang Please help to take a review. |
|
Thank you for your contribution. However, I want to hold this from merging due to this RFC #42770 |
2c3b796 to
3906267
Compare
3906267 to
9d681e0
Compare
|
This pull request has merge conflicts that must be resolved before it can be |
9d681e0 to
e03711e
Compare
e03711e to
18cc072
Compare
|
This pull request has merge conflicts that must be resolved before it can be |
|
I think we can go proceed this PR. Could you take a look at current dsv4 file structure and create a separate folder exclusively to intel/xpu? Thanks |
18cc072 to
ffe42a9
Compare
|
Hi @majian4work, the pre-commit checks have failed. Please run: uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-filesThen, commit the changes and push to your branch. For future commits, Tip Is
|
ffe42a9 to
6ad6803
Compare
|
This pull request has merge conflicts that must be resolved before it can be |
|
Also cc @WoosukKwon |
7d35541 to
45f5043
Compare
|
This pull request has merge conflicts that must be resolved before it can be |
45f5043 to
757ec31
Compare
|
Hi @majian4work, the pre-commit checks have failed. Please run: uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-filesThen, commit the changes and push to your branch. For future commits, Tip Is
|
757ec31 to
eae4f20
Compare
|
@zyongye @WoosukKwon can you take another look? |
877525b to
f106cfa
Compare
|
This pull request has merge conflicts that must be resolved before it can be |
065c5bb to
78cf322
Compare
|
Hi @majian4work, the pre-commit checks have failed. Please run: uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-filesThen, commit the changes and push to your branch. For future commits, Tip Is
|
78cf322 to
335b3a4
Compare
| # negligible overhead for small scale tensors) ensures the kernel sees | ||
| # matching dtypes and correctly enters the block-quant path with the | ||
| # actual group_size derived from scale tensor shapes. | ||
| if scale.dtype == torch.float8_e8m0fnu: |
There was a problem hiding this comment.
will be fixed in kernel side after vllm-project/vllm-xpu-kernels#398
cc @Yejing-Lai
| elif current_platform.is_xpu(): | ||
| from .xpu.model import DeepseekV4ForCausalLM # type: ignore[assignment] | ||
| from .xpu.mtp import DeepSeekV4MTP # type: ignore[assignment] | ||
| else: |
There was a problem hiding this comment.
not certain whether cpu and other OOT device will be affected.
cc @bigPYJ1151
335b3a4 to
2d421b7
Compare
|
Hi @majian4work, the pre-commit checks have failed. Please run: uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-filesThen, commit the changes and push to your branch. For future commits, |
2d421b7 to
1c01c49
Compare
Add XPU-specific decode implementation for DeepSeek-V4 MLA sparse attention. Signed-off-by: Ma Jian <jian1.ma@intel.com>
|
merged as most are xpu only change. thanks. |
Summary
Add XPU-specific decode implementation for DeepSeek-V4 MLA sparse attention, including Triton kernels for FP8 KV cache operations.
Changes
forward_xpufor MHC Pre/Post/Fuse processorsTesting
Tested with DeepSeek-V4 inference (prefill + decode) on Intel XPU.
Notes