[kernel] Fix FP8 paged MQA fallback for CUDA graph capture #36250
ZJY0516 wants to merge 4 commits into vllm-project:main from
Conversation
Code Review
This pull request refactors the PyTorch fallback implementation for FP8 paged MQA to make it compatible with CUDA graph capture. The changes involve vectorizing the implementation to remove host-device synchronization points and correctly handle the packed FP8 KV cache layout, which improves both correctness and performance. No security vulnerabilities were found.
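The vectorization the review describes can be illustrated with a small, framework-free sketch (the names, lengths, and shapes here are hypothetical, not vLLM's): a per-sequence Python loop whose trip count depends on host-side lengths forces a host-device synchronization, which breaks CUDA graph capture, while a fixed-shape computation with a validity mask computes the same result without any data-dependent control flow.

```python
max_len = 6        # assumed fixed upper bound on sequence length
lengths = [3, 5]   # hypothetical per-sequence context lengths (host data)

# Loop form: the iteration count depends on host-side lengths.
loop_sums = [sum(range(n)) for n in lengths]

# Masked form: fixed shape, invalid slots contribute zero via a mask.
masked_sums = [
    sum(j if j < n else 0 for j in range(max_len)) for n in lengths
]

assert loop_sums == masked_sums == [3, 10]
```

The masked form does redundant work on padded slots, but its shapes and control flow are static, which is exactly what graph capture requires.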
    k_scale_ptr + physical_block_id * stride_ks_blk + offs_k * stride_ks_pos,
    mask=token_valid,
    other=0.0,
).to(tl.float16)
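For context, the pointer arithmetic above walks a paged cache: each logical block of tokens maps through a block table to a physical block, and a token's scale lives at `physical_block_id * stride_ks_blk + offs_k * stride_ks_pos`. A minimal sketch of that index mapping, with a made-up block size and block table (not vLLM's actual values):

```python
BLOCK_SIZE = 4           # assumed tokens per page, illustrative only
block_table = [7, 2, 9]  # hypothetical logical-to-physical block mapping

def paged_index(token_pos: int) -> tuple[int, int]:
    """Map a logical token position to (physical_block_id, offset_in_block)."""
    return block_table[token_pos // BLOCK_SIZE], token_pos % BLOCK_SIZE

# Token 5 lives in logical block 1 -> physical block 2, slot 1.
assert paged_index(5) == (2, 1)
```

In the kernel, the same arithmetic is done vectorized over a block of offsets, with `mask=token_valid` guarding lanes that fall past the sequence end.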
Is this cast to FP16 allowed accuracy-wise? The old PyTorch fallback used FP32 for dequantization.
scale = scale.contiguous().view(torch.float)
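To quantify the concern: a half-precision value keeps only 11 significant bits (relative rounding error up to about 2^-11), while float32 keeps 24. A stdlib sketch of the rounding an fp16 cast introduces, using `struct`'s IEEE-754 half format (the sample scale value is arbitrary):

```python
import struct

def to_fp16(x: float) -> float:
    # Round-trip through IEEE-754 half precision, mimicking a cast to fp16.
    return struct.unpack("<e", struct.pack("<e", x))[0]

scale = 0.1234567
assert to_fp16(scale) != scale                # some precision is lost
assert abs(to_fp16(scale) - scale) < 1e-4     # but bounded by half's epsilon
```

Whether that error is acceptable depends on the dynamic range of the per-token scales; the old FP32 path avoided the question entirely.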
logits = torch.full(
    (batch_size * next_n, max_model_len),
    float("-inf"),
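Initializing the full `(batch_size * next_n, max_model_len)` buffer to `-inf` makes out-of-range positions vanish under softmax, since `exp(-inf) == 0`. A framework-free sketch of that masking effect (lengths and logit values are made up):

```python
import math

max_model_len = 8
valid_len = 5
logits = [float("-inf")] * max_model_len  # everything masked by default
for i in range(valid_len):                # fill only the valid prefix
    logits[i] = 0.1 * i

m = max(logits)
exps = [math.exp(x - m) for x in logits]  # exp(-inf) == 0.0
probs = [e / sum(exps) for e in exps]

assert all(p == 0.0 for p in probs[valid_len:])  # masked slots get zero weight
assert abs(sum(probs) - 1.0) < 1e-9
```

If the consumer can be told to ignore invalid positions itself (the `clean_logits=False` point above), this pre-fill becomes unnecessary work.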
`clean_logits=False` is now supported, so we shouldn't have to initialize logits to -inf.
The purpose of this PR is to add a fallback for DeepGEMM and avoid #36519, so performance is not very important.

@ZJY0516 I agree to some extent. It was mainly out of curiosity to get a sense of the cost if DeepGEMM is not installed :)

Will update the accuracy test later.
Thanks for doing this! Since #36519 was merged, could you update this PR to change the reported CG support back to just |
Purpose
`fp8_paged_mqa_logits_torch` is not CUDA-graph compatible. This PR adds a Triton kernel for this.
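A sketch of what a "packed" FP8 layout means here, with a toy block format (the real vLLM FP8 KV cache layout differs; the block size, byte order, and scale placement below are illustrative only): each block stores one-byte FP8 values followed by a float32 dequantization scale, so reading the scale back means reinterpreting raw bytes as float32, much like `scale.contiguous().view(torch.float)` in the old fallback.

```python
import struct

BLOCK_SIZE = 4                                 # assumed tokens per block
fp8_values = bytes([0x10, 0x20, 0x30, 0x40])   # placeholder fp8 payload
scale = 0.5
block = fp8_values + struct.pack("<f", scale)  # pack values, then the scale

# Reinterpret the trailing 4 bytes as a float32 scale.
(recovered,) = struct.unpack("<f", block[BLOCK_SIZE:])
assert recovered == 0.5
```

Getting these byte offsets right inside the kernel is what "correctly handle the packed FP8 KV cache layout" refers to in the review summary.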
Test Plan
Test Result
Essential Elements of an Effective PR Description Checklist
`supported_models.md` and `examples` for a new model.