Skip to content

feat: add DeepSeek-V4 XPU attention decode path#42953

Merged
jikunshang merged 1 commit into
vllm-project:mainfrom
majian4work:dsv4-pr4-attention-decode
Jun 8, 2026
Merged

feat: add DeepSeek-V4 XPU attention decode path#42953
jikunshang merged 1 commit into
vllm-project:mainfrom
majian4work:dsv4-pr4-attention-decode

Conversation

@majian4work
Copy link
Copy Markdown
Contributor

@majian4work majian4work commented May 18, 2026

Summary

Add XPU-specific decode implementation for DeepSeek-V4 MLA sparse attention, including Triton kernels for FP8 KV cache operations.

Changes

  • mhc.py: add forward_xpu for MHC Pre/Post/Fuse processors
  • xpu_qnorm_rope_kv_fp8_insert.py: Triton kernel for fused QK norm + RoPE + FP8 KV insert
  • xpu_sparse_decode_fp8.py: Triton kernel for FP8 sparse MLA decode

Testing

Tested with DeepSeek-V4 inference (prefill + decode) on Intel XPU.

Notes

  • Depends on PR1 (platform guards) and PR3 (FP8 quant) for full functionality
  • Part of DeepSeek-V4 XPU support series (4/4)

@majian4work majian4work requested a review from zyongye as a code owner May 18, 2026 08:44
@mergify mergify Bot added deepseek Related to DeepSeek models intel-gpu Related to Intel GPU v1 labels May 18, 2026
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces XPU support for DeepSeek V4 attention and sparse indexing, including new Triton kernels for normalization, RoPE, and sparse MLA. It also implements a dequantization strategy for FP8 KV caches on XPU. The reviewer identified several performance bottlenecks in the decode path, such as nested Python loops and frequent memory allocations, recommending vectorization and the use of a workspace manager. Additionally, feedback includes correcting hardware-specific FP8 types for XPU and removing debug print statements.

Comment thread vllm/models/deepseek_v4/xpu/xpu_sparse_decode_fp8.py
Comment thread vllm/_xpu_ops.py Outdated
Comment thread vllm/_xpu_ops.py Outdated
Comment on lines +768 to +821
for i in range(batch_size):
seq_len = int(context_lens[i].item())
if seq_len <= 0:
continue
num_pages = cdiv(seq_len, block_size)
pages = block_table[i, :num_pages]
padded_seq_len = num_pages * block_size

# Gather K and scale from block-packed layout
# For each page, K region: [page_start, page_start + bs*dim)
# Scale region: [page_start + bs*dim, page_start + bs*dim + bs*4)
page_starts = pages.to(torch.int64) * block_stride

# K indices: for each page p, for each pos j: page_start + j*dim + d
pos_offsets = torch.arange(block_size, device=device)
dim_offsets = torch.arange(dim, device=device)
# [num_pages, block_size, dim]
k_indices = (
page_starts[:, None, None]
+ pos_offsets[None, :, None] * dim
+ dim_offsets[None, None, :]
)
k_uint8 = cache_flat[k_indices.reshape(-1)].reshape(
padded_seq_len, dim
)
k_fp8 = k_uint8.view(fp8_dtype).to(torch.float32)

# Scale indices: for each page p, for each pos j:
# page_start + bs*dim + j*4 + byte
scale_byte_offsets = torch.arange(4, device=device)
scale_indices = (
page_starts[:, None, None]
+ block_size * dim
+ pos_offsets[None, :, None] * 4
+ scale_byte_offsets[None, None, :]
)
scale_uint8 = cache_flat[scale_indices.reshape(-1)].reshape(
padded_seq_len, 4
)
k_scale = scale_uint8.view(torch.float32) # [padded_seq_len, 1]

for n in range(next_n):
q_i = q[i, n].to(torch.float32) # [H, D]
w_i = weights[i * next_n + n] # [H]

# Compute per-head scores: [H, padded_seq_len]
scores = torch.mm(q_i, k_fp8[:padded_seq_len].T) # [H, S]
scores = scores * k_scale[:padded_seq_len].T # broadcast scale
scores = torch.relu(scores)
# Weight and sum over heads
weighted = (scores * w_i[:, None]).sum(dim=0) # [S]
logits[i * next_n + n, :seq_len] = weighted[:seq_len]

return logits
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The nested Python loops and per-request torch.mm calls in fp8_paged_mqa_logits will cause significant performance bottlenecks during decode. This implementation should be vectorized or replaced with a Triton kernel to ensure efficient execution on XPU.

Comment on lines +184 to +201
workspace = torch.empty(
(num_tokens * K_total, OUTPUT_DIM), dtype=torch.bfloat16, device=device
)
ws_3d = workspace.view(num_tokens, K_total, OUTPUT_DIM)

# Dequant+gather topk slots from compressed cache
if not swa_only and topk_idx_2d is not None and kv_cache is not None:
topk_flat = topk_idx_2d.reshape(-1).to(torch.int32)
topk_buf = torch.empty(
(num_tokens * max_topk, OUTPUT_DIM), dtype=torch.bfloat16, device=device
)
compressed_block_size = kv_cache.shape[1]
dequant_gather_slots(topk_buf, kv_cache, topk_flat, compressed_block_size)
ws_3d[:, :max_topk, :] = topk_buf.view(num_tokens, max_topk, OUTPUT_DIM)

# Dequant+gather SWA slots
swa_flat = swa_idx_2d.reshape(-1).to(torch.int32)
swa_buf = torch.empty(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Frequent torch.empty allocations in the decode hot path (lines 184, 192, 201) will cause significant performance degradation due to memory management overhead. Use the workspace_manager to obtain pre-allocated buffers or manage a persistent workspace for these operations.

Comment on lines +256 to +261
for t_idx in range(num_tokens):
tlen = int(topk_lens[t_idx].item())
slen = int(swa_lens[t_idx].item())
combined_indices[t_idx, tlen:tlen + slen] = (
swa_ws_indices[t_idx, :slen]
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The Python loop over num_tokens to pack indices is inefficient and will become a bottleneck as the batch size increases. This operation should be vectorized using PyTorch operations like torch.scatter_ or advanced indexing.

num_heads = q.shape[1]

# Allocate temp buffer for RoPE-applied KV
kv_roped = torch.empty_like(kv)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Allocating kv_roped using torch.empty_like in the hot path will lead to performance overhead. Consider using a pre-allocated workspace or the workspace_manager.

@majian4work
Copy link
Copy Markdown
Contributor Author

@jikunshang @xinyu-intel @xuechendi @wuxun-zhang Please help to take a review.

Comment thread vllm/v1/attention/ops/deepseek_v4_ops/xpu_sparse_mla_bf16.py Outdated
Comment thread vllm/model_executor/layers/deepseek_v4_attention.py Outdated
@zyongye
Copy link
Copy Markdown
Member

zyongye commented May 18, 2026

Thank you for your contribution. However, I want to hold this from merging due to this RFC #42770

@majian4work majian4work force-pushed the dsv4-pr4-attention-decode branch from 2c3b796 to 3906267 Compare May 19, 2026 02:33
@majian4work majian4work force-pushed the dsv4-pr4-attention-decode branch from 3906267 to 9d681e0 Compare May 19, 2026 05:32
@mergify
Copy link
Copy Markdown
Contributor

mergify Bot commented May 19, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @majian4work.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label May 19, 2026
@majian4work majian4work force-pushed the dsv4-pr4-attention-decode branch from 9d681e0 to e03711e Compare May 19, 2026 06:19
@mergify mergify Bot removed the needs-rebase label May 19, 2026
Comment thread vllm/models/deepseek_v4/attention.py Outdated
Comment thread vllm/model_executor/layers/sparse_attn_indexer.py Outdated
Comment thread vllm/model_executor/layers/sparse_attn_indexer.py Outdated
Comment thread vllm/_xpu_ops.py Outdated
@majian4work majian4work force-pushed the dsv4-pr4-attention-decode branch from e03711e to 18cc072 Compare May 19, 2026 08:15
@mergify
Copy link
Copy Markdown
Contributor

mergify Bot commented May 23, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @majian4work.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label May 23, 2026
@zyongye
Copy link
Copy Markdown
Member

zyongye commented May 24, 2026

I think we can go proceed this PR. Could you take a look at current dsv4 file structure and create a separate folder exclusively to intel/xpu? Thanks

@majian4work majian4work force-pushed the dsv4-pr4-attention-decode branch from 18cc072 to ffe42a9 Compare May 25, 2026 03:15
@mergify mergify Bot removed the needs-rebase label May 25, 2026
@mergify
Copy link
Copy Markdown
Contributor

mergify Bot commented May 25, 2026

Hi @majian4work, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@majian4work majian4work force-pushed the dsv4-pr4-attention-decode branch from ffe42a9 to 6ad6803 Compare May 25, 2026 03:32
Comment thread vllm/models/deepseek_v4/attention.py
@mergify
Copy link
Copy Markdown
Contributor

mergify Bot commented May 27, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @majian4work.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label May 27, 2026
@jikunshang
Copy link
Copy Markdown
Member

@zyongye ci failed on CPU, which is an existing bug after adding dequant_gather_k_cutedsl.py, how to fix it?

make a patch here #43773 but I feel lazy import is not final solution.

@zyongye
Copy link
Copy Markdown
Member

zyongye commented May 27, 2026

@zyongye ci failed on CPU, which is an existing bug after adding dequant_gather_k_cutedsl.py, how to fix it?

make a patch here #43773 but I feel lazy import is not final solution.

Yea I think we will need to import cutedsl only in cuda.

@zyongye zyongye requested a review from WoosukKwon May 27, 2026 16:16
@zyongye
Copy link
Copy Markdown
Member

zyongye commented May 27, 2026

Also cc @WoosukKwon

@majian4work majian4work force-pushed the dsv4-pr4-attention-decode branch from 7d35541 to 45f5043 Compare May 28, 2026 00:40
@mergify
Copy link
Copy Markdown
Contributor

mergify Bot commented May 28, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @majian4work.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label May 28, 2026
@majian4work majian4work force-pushed the dsv4-pr4-attention-decode branch from 45f5043 to 757ec31 Compare May 28, 2026 08:24
@mergify
Copy link
Copy Markdown
Contributor

mergify Bot commented May 28, 2026

Hi @majian4work, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@mergify mergify Bot removed the needs-rebase label May 28, 2026
@majian4work majian4work force-pushed the dsv4-pr4-attention-decode branch from 757ec31 to eae4f20 Compare May 28, 2026 08:48
@jikunshang
Copy link
Copy Markdown
Member

@zyongye @WoosukKwon can you take another look?

@mergify
Copy link
Copy Markdown
Contributor

mergify Bot commented Jun 4, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @majian4work.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Jun 4, 2026
@majian4work majian4work force-pushed the dsv4-pr4-attention-decode branch 2 times, most recently from 065c5bb to 78cf322 Compare June 5, 2026 07:47
@mergify mergify Bot removed the needs-rebase label Jun 5, 2026
@mergify
Copy link
Copy Markdown
Contributor

mergify Bot commented Jun 5, 2026

Hi @majian4work, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@majian4work majian4work force-pushed the dsv4-pr4-attention-decode branch from 78cf322 to 335b3a4 Compare June 5, 2026 08:14
Copy link
Copy Markdown
Member

@jikunshang jikunshang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

Comment thread vllm/models/deepseek_v4/compressor.py Outdated
# negligible overhead for small scale tensors) ensures the kernel sees
# matching dtypes and correctly enters the block-quant path with the
# actual group_size derived from scale tensor shapes.
if scale.dtype == torch.float8_e8m0fnu:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will be fixed in kernel side after vllm-project/vllm-xpu-kernels#398
cc @Yejing-Lai

elif current_platform.is_xpu():
from .xpu.model import DeepseekV4ForCausalLM # type: ignore[assignment]
from .xpu.mtp import DeepSeekV4MTP # type: ignore[assignment]
else:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not certain whether cpu and other OOT device will be affected.
cc @bigPYJ1151

@majian4work majian4work force-pushed the dsv4-pr4-attention-decode branch from 335b3a4 to 2d421b7 Compare June 8, 2026 03:06
@mergify
Copy link
Copy Markdown
Contributor

mergify Bot commented Jun 8, 2026

Hi @majian4work, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

@majian4work majian4work force-pushed the dsv4-pr4-attention-decode branch from 2d421b7 to 1c01c49 Compare June 8, 2026 03:26
Add XPU-specific decode implementation for DeepSeek-V4 MLA sparse attention.

Signed-off-by: Ma Jian <jian1.ma@intel.com>
@jikunshang
Copy link
Copy Markdown
Member

merged as most are xpu only change. thanks.

@jikunshang jikunshang merged commit eebce65 into vllm-project:main Jun 8, 2026
64 checks passed
knight0528 pushed a commit to knight0528/vllm that referenced this pull request Jun 8, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

deepseek Related to DeepSeek models intel-gpu Related to Intel GPU ready ONLY add when PR is ready to merge/full CI is needed v1 verified Run pre-commit for new contributors without triggering other tests

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants