[CHUNK_PREFILL] fp8kv cache #128
Conversation
YizhouZ commented on Jan 25, 2026:
- add fp8kv cache
- currently test kv scales are set to 1.0
Signed-off-by: Yizhou Wang <yizhou.wang@intel.com>
Pull request overview
This PR adds support for FP8 KV cache quantization to the flash attention implementation. The changes enable storing key and value tensors in float8 formats (e4m3fn and e5m2) with descaling during computation to improve memory efficiency.
Changes:
- Added k_scale and v_scale parameters throughout the attention pipeline to support FP8 KV cache descaling
- Extended type system to handle separate Q and K/V data types via CutlassQKType struct
- Added test coverage for FP8 KV cache variants in the test suite
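Conceptually, FP8 KV caching stores K/V values divided by a per-tensor scale and multiplies the scale back in during attention. A minimal pure-Python sketch of this static per-tensor scheme (plain floats stand in for real float8 storage; the function names are illustrative, and 448 is the largest finite magnitude of the e4m3fn format):

```python
# Sketch of static per-tensor FP8-style KV quantization (illustrative only;
# plain Python floats stand in for real float8 values).
E4M3_MAX = 448.0  # largest finite magnitude representable in e4m3fn

def quantize(values, scale):
    # Store value / scale, clamped to the representable FP8 range.
    return [max(-E4M3_MAX, min(E4M3_MAX, v / scale)) for v in values]

def descale(stored, scale):
    # Kernel-side dequantize: multiply the static scale back in.
    return [s * scale for s in stored]

# Round-trip with scale 2.0 (exact here, since 2.0 is a power of two).
restored = descale(quantize([1.5, -3.0, 0.25], 2.0), 2.0)
```

With a scale of 1.0, as in the PR's current tests, quantize/descale is an identity apart from range clamping, which is why the test scales are set to 1.0 first.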
Reviewed changes
Copilot reviewed 18 out of 18 changed files in this pull request and generated 1 comment.
Summary per file:
| File | Description |
|---|---|
| vllm_xpu_kernels/flash_attn_interface.py | Added k_descale and v_descale parameters with default value 1.0, added validation for FA2 KV descaling support |
| tests/flash_attn/test_flash_attn_varlen_func.py | Added FP8KV test parameter and is_fp8kv flag to test FP8 quantized key/value caches |
| csrc/xpu/attn/xe_2/paged_decode_xe2.cpp | Renamed CutlassType to CutlassDType for consistency |
| csrc/xpu/attn/xe_2/paged_decode_utils.hpp | Updated function signatures to use CutlassDType |
| csrc/xpu/attn/xe_2/paged_decode_kernel_template.cpp.in | Updated template to use CutlassDType |
| csrc/xpu/attn/xe_2/paged_decode_extern.hpp | Updated extern template declarations to use CutlassDType |
| csrc/xpu/attn/xe_2/paged_decode.hpp | Updated dispatch implementation to use CutlassDType |
| csrc/xpu/attn/xe_2/fmha_xe2.h | Added k_scale and v_scale parameters to function signature |
| csrc/xpu/attn/xe_2/fmha_xe2.cpp | Added k_scale and v_scale parameters and updated to use CutlassQKType |
| csrc/xpu/attn/xe_2/fmha_utils.hpp | Refactored type system with CutlassDType enum, CutlassQKType struct, and FP8 support |
| csrc/xpu/attn/xe_2/collective/chunk_prefill_mainloop.hpp | Added FP8 KV cache descaling logic in mainloop |
| csrc/xpu/attn/xe_2/chunk_prefill_utils.hpp | Updated function signatures to use CutlassQKType |
| csrc/xpu/attn/xe_2/chunk_prefill_kernel_template.cpp.in | Updated template to use CutlassQKType |
| csrc/xpu/attn/xe_2/chunk_prefill_extern.hpp | Updated extern template declarations to use CutlassQKType |
| csrc/xpu/attn/xe_2/chunk_prefill.hpp | Added k_scale/v_scale to args, implemented dispatch logic for FP8 KV types |
| csrc/xpu/attn/attn_interface.h | Added k_scale and v_scale parameters to interface |
| csrc/xpu/attn/attn_interface.cpp | Threaded k_scale and v_scale parameters through implementation |
| csrc/flash_attn/flash_api.cpp | Added k_scale and v_scale parameters, added FP8 KV dtype validation |
```cpp
if constexpr (Fp8KV) {
  for (int i = 0; i < tArV.size(); ++i) {
    tArV(i) = static_cast<ElementQ>(
        params.scale_k * static_cast<float>(tArV(i)));
```
The descaling operation uses params.scale_k instead of params.scale_v for the V tensor. This should use params.scale_v since it's processing the value cache, not the key cache.
Suggested change:
```diff
-        params.scale_k * static_cast<float>(tArV(i)));
+        params.scale_v * static_cast<float>(tArV(i)));
```
Seems Copilot is correct. Also, can this loop be vectorized?
fixed & added PV part unroll.
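For reference, the unrolling discussed above can be modeled in pure Python: the per-element descale multiply is repeated a fixed number of times per loop iteration, followed by a scalar remainder loop. This is only a stand-in sketch; the real kernel unrolls via compiler pragmas, not manual chunking.

```python
# Pure-Python model of an unrolled descale loop (illustrative only; the
# actual kernel relies on compiler unrolling, not manual chunking).
def descale_unrolled(vals, scale, unroll=4):
    out = list(vals)
    n, i = len(out), 0
    while i + unroll <= n:
        # Loop body repeated `unroll` times per "iteration".
        for j in range(unroll):
            out[i + j] *= scale
        i += unroll
    while i < n:  # scalar remainder loop for the tail
        out[i] *= scale
        i += 1
    return out
```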
```python
if k_descale is None:
    k_descale = 1.0
if v_descale is None:
    v_descale = 1.0
```
q, k, v descale should be torch device tensors with shape [num_sequences, num_kv_heads].
good point, but I think we can start with static scale for now.
ref. vllm-project/vllm#30141
Correct. We can start with a static per-tensor scale in the kernel implementation, but at the flash_attn API level we need to unify them into the same shape. Consider the following check:
```python
>>> a = torch.tensor([0.1]).to("xpu")
>>> a = a.expand((10, 10))
>>> a.size()
torch.Size([10, 10])
>>> a.stride()
(0, 0)
>>> a.untyped_storage().nbytes()
4
```
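The check above works because `expand` produces a stride-0 view over a single 4-byte element rather than copying data. A minimal pure-Python model of that addressing (the class name and layout are illustrative, not PyTorch internals):

```python
class StridedView:
    """Toy model of a strided tensor view over flat storage."""
    def __init__(self, storage, shape, strides):
        self.storage, self.shape, self.strides = storage, shape, strides

    def __getitem__(self, idx):
        # Stride-0 dimensions map every index to the same storage element.
        offset = sum(i * s for i, s in zip(idx, self.strides))
        return self.storage[offset]

# One scalar scale presented as a [10, 10] tensor without any copy.
scale = StridedView([0.1], (10, 10), (0, 0))
```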
```diff
-    "int max_seqlen_q, int max_seqlen_k, float p_dropout, float "
+    "int max_seqlen_q, int max_seqlen_k, float p_dropout, float k_scale, "
+    "float v_scale, "
     "softmax_scale, Tensor? softmax_sink, bool zero_tensors, "
```
```diff
-                 half_t,
-                 half_t,
-                 half_t>::kernel_dispatch(queue, args);
+  if (cuQKType.q_type == CutlassDType::half) {
```
I'd prefer to make the Q/KV types template parameters here as well in the future; I feel the current change will increase compile time a little.
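The runtime dispatch under discussion can be pictured as a lookup keyed by the (Q, KV) dtype pair. A hypothetical Python model follows; the dtype strings echo the PR's `CutlassDType` values, but the functions and table are stand-ins, not the actual kernel API:

```python
# Hypothetical model of (q_type, kv_type) kernel dispatch; the kernels here
# are placeholders that just report which variant would have been launched.
def kernel_dispatch(q_type, kv_type, args):
    def half_half(a):
        return ("half/half kernel", a)

    def half_e4m3(a):
        return ("half/e4m3 kernel", a)

    table = {
        ("half", "half"): half_half,
        ("half", "float_e4m3_t"): half_e4m3,
    }
    if (q_type, kv_type) not in table:
        raise ValueError(f"unsupported dtype pair: {q_type}/{kv_type}")
    return table[(q_type, kv_type)](args)
```

Making the pair a compile-time template parameter instead would move this branching to instantiation time, at the cost of more kernel variants to compile.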
```cpp
// User-facing arguments
struct Arguments {
  ElementS const scale;
```
What is `scale` here for?