[ROCm][FP8][Kernel] FP8 quantization fused into Custom Paged Attention #17139
vllm-bot merged 5 commits into vllm-project:main
Conversation
…d output FP8 tensor Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a limited set of checks runs. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either: Add 🚀
…uantizing in the flash attention kernel for V1 Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
ProExpertProg left a comment
2 nits, and could we add this case to tests?
csrc/rocm/attention.cu
Outdated
// NOTE: fp8_out_scale is optional.
const float* fp8_out_scale_ptr =
    fp8_out_scale
        ? reinterpret_cast<const float*>(fp8_out_scale.value().data_ptr())
        : nullptr;
OUTT* out_ptr = reinterpret_cast<OUTT*>(out.data_ptr());
Should the OUTT type be fp8 if scale is given? Is that captured automatically? Maybe we could assert this somewhere
Also, should tmp_output be the same type as output? So if output is fp8, is tmp_output also fp8?
Should the OUTT type be fp8 if scale is given? Is that captured automatically? Maybe we could assert this somewhere
This is ensured at https://github.com/vllm-project/vllm/pull/17139/files#diff-79b8261aa73f07cc7450e48c8e14150576656f19ccfb42ba972860092c1f5949R1779-R1786
Also, should tmp_output be the same type as output? So if output is fp8, is tmp_output also fp8?
No, it should be the same type as the query; it is used in the internal calculations.
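To make the invariant from this exchange concrete, here is a hedged Python-level sketch (the names `check_out_dtype`, `FP8_DTYPES`, and the string dtypes are hypothetical; the real guard is the C++ dispatch check in csrc/rocm/attention.cu linked above): if an fp8 output scale is supplied, the output must use an fp8 dtype, while `tmp_output` stays in the query's higher-precision dtype.

```python
# Hypothetical sketch of the dispatch-time invariant discussed above.
# The actual check lives in the C++ dispatch code, not in Python.

FP8_DTYPES = {"float8_e4m3fn", "float8_e4m3fnuz"}  # assumed platform fp8 types


def check_out_dtype(out_dtype: str, query_dtype: str,
                    has_fp8_out_scale: bool) -> str:
    """Validate the output dtype and return the dtype tmp_output should use."""
    if has_fp8_out_scale and out_dtype not in FP8_DTYPES:
        raise ValueError(
            f"fp8_out_scale given but output dtype is {out_dtype}; "
            "expected an fp8 dtype")
    # tmp_output holds intermediate accumulations, so it matches the query.
    return query_dtype


# Example: fp8 output with fp16 intermediates.
tmp_dtype = check_out_dtype("float8_e4m3fnuz", "float16", True)
```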
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
commit 9f733ff
Author: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Date: Fri Apr 25 22:10:58 2025 +0000

    Using static cast

    Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>

commit 2d7dba5
Author: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Date: Thu Apr 24 21:37:16 2025 +0000

    An option to apply fp8 output scale in ROCm custom paged attention and output FP8 tensor

    Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
csrc/rocm/attention.cu
Outdated
// final write to tmp_out after vout accumulation
if (warpid == 0) {
  const float out_scale =
wondering where out_scale is used here?
It is actually used in the reduction kernel launched after either of the attention kernels.
The dereferencing here is indeed not needed, but it'll get optimized out. I'll make a note to clean it up
Could you just remove it in this PR?
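For context on what the reduction kernel eventually does with the scale, here is a rough Python model (a sketch, not the kernel: it assumes the epilogue multiplies the accumulated float values by the dereferenced scale factor and saturates them into the fp8 range; `apply_out_scale` and the constant are illustrative, and `FP8_MAX` assumes the e4m3fnuz format whose max finite value is 240.0):

```python
# Illustrative model of the reduction epilogue: apply the output scale
# factor to the accumulated values, then saturate into the fp8 range
# before the cast to an 8-bit type.

FP8_MAX = 240.0  # max finite value of float8_e4m3fnuz; 448.0 for float8_e4m3fn


def apply_out_scale(acc: list, scale: float) -> list:
    """Scale accumulated outputs and clamp to the representable fp8 range."""
    return [max(-FP8_MAX, min(FP8_MAX, v * scale)) for v in acc]


# Values beyond the fp8 range saturate instead of overflowing.
scaled = apply_out_scale([1.5, -300.0, 1000.0], 1.0)
```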
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
An option to apply fp8 output scale in ROCm custom paged attention and output FP8 tensor
If a non-None scale tensor is passed to the kernel, the output tensor is expected to be of the current_platform.fp8_dtype() type (float8_e4m3fnuz or float8_e4m3fn), and the scale is applied to the values before they are stored into the 8-bit type.
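To illustrate the scale semantics in the description, here is a toy round trip (a sketch, not vLLM's API: real code uses torch fp8 dtypes and the kernel applies the scale on-device; `fake_fp8_store` and `dequantize` are hypothetical names, and it assumes the common convention that quantization divides by the scale while dequantization multiplies by it):

```python
# Toy round trip for the scale convention described above: divide by the
# scale before storing to 8 bits, multiply by it to dequantize. The fp8
# cast is faked here with a saturating clamp (rounding is ignored).

FP8_MAX = 240.0  # float8_e4m3fnuz max; float8_e4m3fn would use 448.0


def fake_fp8_store(x: float, scale: float) -> float:
    """Quantize: scale down, then saturate to the fp8 dynamic range."""
    return max(-FP8_MAX, min(FP8_MAX, x / scale))


def dequantize(q: float, scale: float) -> float:
    """Recover an approximation of the original value."""
    return q * scale


scale = 2.0
q = fake_fp8_store(500.0, scale)   # 500 / 2 = 250, saturates to 240
restored = dequantize(q, scale)    # 480 after saturation loss
```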