
[kernel] Fix FP8 paged MQA fallback for CUDA graph capture#36250

Open
ZJY0516 wants to merge 4 commits intovllm-project:mainfrom
ZJY0516:fix_v32_fallback

Conversation


@ZJY0516 ZJY0516 commented Mar 6, 2026

Purpose

fp8_paged_mqa_logits_torch is not CUDA graph compatible:

(Worker pid=3782129) (Worker_TP1 pid=3782129) ERROR 03-07 11:21:54 [multiproc_executor.py:927]   File "/mnt/data1/zjy/code/vllm-src/vllm/model_executor/layers/sparse_attn_indexer.py", line 193, in sparse_attn_indexer
(Worker pid=3782129) (Worker_TP1 pid=3782129) ERROR 03-07 11:21:54 [multiproc_executor.py:927]     logits = fp8_paged_mqa_logits_torch(
(Worker pid=3782129) (Worker_TP1 pid=3782129) ERROR 03-07 11:21:54 [multiproc_executor.py:927]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker pid=3782129) (Worker_TP1 pid=3782129) ERROR 03-07 11:21:54 [multiproc_executor.py:927]   File "/mnt/data1/zjy/code/vllm-src/vllm/utils/deep_gemm.py", line 510, in fp8_paged_mqa_logits_torch
(Worker pid=3782129) (Worker_TP1 pid=3782129) ERROR 03-07 11:21:54 [multiproc_executor.py:927]     context_len = context_lens[i].item()
(Worker pid=3782129) (Worker_TP1 pid=3782129) ERROR 03-07 11:21:54 [multiproc_executor.py:927]                   ^^^^^^^^^^^^^^^^^^^^^^
(Worker pid=3782129) (Worker_TP1 pid=3782129) ERROR 03-07 11:21:54 [multiproc_executor.py:927] torch.AcceleratorError: CUDA error: operation not permitted when stream is capturing
(Worker pid=3782129) (Worker_TP1 pid=3782129) ERROR 03-07 11:21:54 [multiproc_executor.py:927] Search for `cudaErrorStreamCaptureUnsupported' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.

This PR adds a Triton kernel to replace it.
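The capture-incompatible pattern in the traceback is the per-row host read (context_lens[i].item()); the graph-safe pattern builds the validity mask with tensor ops only. A minimal sketch of the two patterns, using NumPy in place of torch/Triton (function names here are illustrative, not from the PR):

```python
import numpy as np

def mask_loop(context_lens, max_model_len):
    # Per-row loop with a host read of the length (stands in for .item()).
    # On GPU this forces a device->host sync, which is illegal while a
    # CUDA graph is being captured.
    batch = len(context_lens)
    logits = np.full((batch, max_model_len), -np.inf, dtype=np.float32)
    for i in range(batch):
        ctx = int(context_lens[i])
        logits[i, :ctx] = 0.0
    return logits

def mask_vectorized(context_lens, max_model_len):
    # Fully vectorized equivalent: positions < context_len are valid,
    # no host reads, so it is safe under stream capture.
    positions = np.arange(max_model_len)[None, :]
    valid = positions < context_lens[:, None]
    return np.where(valid, 0.0, -np.inf).astype(np.float32)

lens = np.array([3, 5, 1])
assert np.array_equal(mask_loop(lens, 6), mask_vectorized(lens, 6))
```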

Test Plan

VLLM_USE_DEEP_GEMM=0 vllm serve /mnt/data3/DSModels/models/deepseek-ai/DeepSeek-V3.2/ -tp 8 --served-model-name deepseek-ai/DeepSeek-V3.2 --tokenizer-mode deepseek_v32 --enable-auto-tool-choice --tool-call-parser deepseek_v32 --reasoning-parser deepseek_v3

Test Result

curl http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer EMPTY" \
-d '{
"model": "deepseek-ai/DeepSeek-V3.2",
"messages": [
{
"role": "user",
"content": "Solve this problem step by step: What is 15% of 4800?"
}
],
"max_tokens": 2048
}'
{"id":"chatcmpl-a95500daaf6a3517","object":"chat.completion","created":1772853307,"model":"deepseek-ai/DeepSeek-V3.2","choices":[{"index":0,"message":{"role":"assistant","content":"Alright, let's go step by step.  \n\n---\n\n**Step 1: Understand the question**  \nWe want to find 15% of 4800.  \n\"Percent\" means \"per hundred,\" so 15% means \\( 15 / 100 \\).\n\n---\n\n**Step 2: Convert percentage to decimal**  \n\\[\n15\\% = \\frac{15}{100} = 0.15\n\\]\n\n---\n\n**Step 3: Multiply decimal by the number**  \n\\[\n0.15 \\times 4800\n\\]\n\nFirst, \\( 0.15 = \\frac{15}{100} \\), so:  \n\\[\n0.15 \\times 4800 = \\frac{15}{100} \\times 4800\n\\]\n\n---\n\n**Step 4: Simplify the fraction multiplication**  \n\\[\n\\frac{15 \\times 4800}{100} = 15 \\times 48\n\\]\n(because \\( 4800 \\div 100 = 48 \\))\n\n---\n\n**Step 5: Multiply**  \n\\[\n15 \\times 48 = 720\n\\]\n\n---\n\n**Final answer:**  \n\\[\n\\boxed{720}\n\\]","refusal":null,"annotations":null,"audio":null,"function_call":null,"tool_calls":[],"reasoning":null},"logprobs":null,"finish_reason":"stop","stop_reason":null,"token_ids":null}],"service_tier":null,"system_fingerprint":null,"usage":{"prompt_tokens":21,"total_tokens":264,"completion_tokens":243,"prompt_tokens_details":null},"prompt_logprobs":null,"prompt_token_ids":null,"kv_transfer_params":null}

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
@mergify mergify bot added the nvidia and bug labels Mar 6, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request refactors the PyTorch fallback implementation for FP8 paged MQA to make it compatible with CUDA graph capture. The changes involve vectorizing the implementation to remove host-device synchronization points and correctly handle the packed FP8 KV cache layout, which improves both correctness and performance. No security vulnerabilities were found.

Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
ZJY0516 added 2 commits March 7, 2026 00:13
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
@ZJY0516 ZJY0516 changed the title [Bugfix] Fix FP8 paged MQA fallback for CUDA graph capture [kernel] Fix FP8 paged MQA fallback for CUDA graph capture Mar 7, 2026
Contributor

@LopezCastroRoberto LopezCastroRoberto left a comment

Can you add some e2e performance and accuracy numbers?

k_scale_ptr + physical_block_id * stride_ks_blk + offs_k * stride_ks_pos,
mask=token_valid,
other=0.0,
).to(tl.float16)
Contributor

Is this cast to FP16 allowed accuracy-wise? The old PyTorch fallback used FP32 for dequantization.
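Background on the concern: float16 has a 10-bit mantissa, so long accumulations can silently drop contributions that float32 keeps. Whether this matters here depends on where the accumulation happens (Triton's tl.dot accumulates in fp32 by default; only the loaded values are cast). A self-contained illustration, unrelated to the actual kernel:

```python
import numpy as np

# float16 accumulation stalls once the running sum is large enough that
# the next addend is at most half a ulp: at 2048, ulp(float16) == 2, so
# 2048 + 1 is a round-to-even tie and rounds back to 2048.
acc16 = np.float16(0.0)
acc32 = np.float32(0.0)
for _ in range(3000):
    acc16 = np.float16(acc16 + np.float16(1.0))
    acc32 = np.float32(acc32 + np.float32(1.0))

print(float(acc16))  # 2048.0 -- 952 increments silently lost
print(float(acc32))  # 3000.0
```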

scale = scale.contiguous().view(torch.float)


logits = torch.full(
(batch_size * next_n, max_model_len),
float("-inf"),
Contributor

clean_logits=False is now supported, so we shouldn't have to initialize logits to -inf
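For readers following along: the -inf pre-fill exists so positions beyond each sequence's context length can never win the indexer's top-k. A toy sketch of that masking idea (made-up shapes and data, not the PR's code):

```python
import numpy as np

batch, max_len, k = 2, 6, 2
context_lens = np.array([3, 5])
scores = np.arange(batch * max_len, dtype=np.float32).reshape(batch, max_len)

# Start from -inf so out-of-range slots lose every comparison.
logits = np.full((batch, max_len), -np.inf, dtype=np.float32)
valid = np.arange(max_len)[None, :] < context_lens[:, None]
logits[valid] = scores[valid]

# Every selected index falls inside its row's context window.
topk = np.argsort(-logits, axis=1)[:, :k]
assert (topk < context_lens[:, None]).all()
```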

@ZJY0516
Member Author

ZJY0516 commented Mar 10, 2026

Can you add some e2e performance and accuracy numbers, so we can understand the impact of this PR?

The purpose of this PR is to add a fallback for DeepGEMM and to avoid #36519, so performance is not critical here.

@LopezCastroRoberto
Contributor

@ZJY0516 I agree to some extent. It was mainly out of curiosity to get a sense of the cost when DeepGEMM is not installed :)

@ZJY0516
Member Author

ZJY0516 commented Mar 10, 2026

Will add accuracy test results later.

@MatthewBonanni
Collaborator

Thanks for doing this! Since #36519 was merged, could you update this PR to change the reported CG support back to just UNIFORM_BATCH?


Labels

bug (Something isn't working), nvidia, v1
