[CHUNK_PREFILL] fp8kv cache#128

Merged
YizhouZ merged 5 commits into vllm-project:main from YizhouZ:dev/fp8kv
Jan 26, 2026
Conversation

YizhouZ (Collaborator) commented Jan 25, 2026

  • add fp8kv cache
  • currently, test KV scales are set to 1.0

Copilot AI review requested due to automatic review settings January 25, 2026 15:38
@YizhouZ YizhouZ changed the title fp8kv cache [CHUNK_PREFILL] fp8kv cache Jan 25, 2026
@YizhouZ YizhouZ requested review from baodii, jikunshang and xinyu-intel and removed request for Copilot and jikunshang January 25, 2026 15:38
Signed-off-by: Yizhou Wang <yizhou.wang@intel.com>
Copilot AI (Contributor) left a comment:
Pull request overview

This PR adds support for FP8 KV cache quantization to the flash attention implementation. The changes enable storing key and value tensors in float8 formats (e4m3fn and e5m2) with descaling during computation to improve memory efficiency.

Changes:

  • Added k_scale and v_scale parameters throughout the attention pipeline to support FP8 KV cache descaling
  • Extended type system to handle separate Q and K/V data types via CutlassQKType struct
  • Added test coverage for FP8 KV cache variants in the test suite
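As a concrete illustration of the descaling idea, here is a sketch only: integers stand in for fp8 storage, and all names and numbers below are made up for illustration, not taken from the kernels.

```python
# Sketch of per-tensor FP8 KV-cache scaling: values are stored in a
# low-precision format alongside a per-tensor scale, and multiplied back
# by that scale ("descaled") during attention.

def quantize(values, scale, max_repr=448.0):
    # e4m3fn represents magnitudes up to 448: divide by the scale, clamp,
    # then round (a crude stand-in for real fp8 rounding).
    return [round(max(-max_repr, min(max_repr, v / scale))) for v in values]

def descale(stored, scale):
    # What the attention mainloop does per element: out = scale * stored.
    return [scale * q for q in stored]

k = [0.5, -1.25, 2.0]
k_scale = 0.25                  # this PR's tests currently pin scales to 1.0
k_stored = quantize(k, k_scale)
k_restored = descale(k_stored, k_scale)
```

With a scale of 1.0, as in the PR's current tests, the quantize/descale round trip reduces to plain rounding, which keeps the numerics easy to check.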

Reviewed changes

Copilot reviewed 18 out of 18 changed files in this pull request and generated 1 comment.

Summary per file:

  • vllm_xpu_kernels/flash_attn_interface.py — Added k_descale and v_descale parameters (default 1.0) and validation for FA2 KV descaling support
  • tests/flash_attn/test_flash_attn_varlen_func.py — Added an FP8KV test parameter and is_fp8kv flag to test FP8-quantized key/value caches
  • csrc/xpu/attn/xe_2/paged_decode_xe2.cpp — Renamed CutlassType to CutlassDType for consistency
  • csrc/xpu/attn/xe_2/paged_decode_utils.hpp — Updated function signatures to use CutlassDType
  • csrc/xpu/attn/xe_2/paged_decode_kernel_template.cpp.in — Updated template to use CutlassDType
  • csrc/xpu/attn/xe_2/paged_decode_extern.hpp — Updated extern template declarations to use CutlassDType
  • csrc/xpu/attn/xe_2/paged_decode.hpp — Updated dispatch implementation to use CutlassDType
  • csrc/xpu/attn/xe_2/fmha_xe2.h — Added k_scale and v_scale parameters to the function signature
  • csrc/xpu/attn/xe_2/fmha_xe2.cpp — Added k_scale and v_scale parameters and switched to CutlassQKType
  • csrc/xpu/attn/xe_2/fmha_utils.hpp — Refactored the type system with a CutlassDType enum, a CutlassQKType struct, and FP8 support
  • csrc/xpu/attn/xe_2/collective/chunk_prefill_mainloop.hpp — Added FP8 KV-cache descaling logic to the mainloop
  • csrc/xpu/attn/xe_2/chunk_prefill_utils.hpp — Updated function signatures to use CutlassQKType
  • csrc/xpu/attn/xe_2/chunk_prefill_kernel_template.cpp.in — Updated template to use CutlassQKType
  • csrc/xpu/attn/xe_2/chunk_prefill_extern.hpp — Updated extern template declarations to use CutlassQKType
  • csrc/xpu/attn/xe_2/chunk_prefill.hpp — Added k_scale/v_scale to args and implemented dispatch logic for FP8 KV types
  • csrc/xpu/attn/attn_interface.h — Added k_scale and v_scale parameters to the interface
  • csrc/xpu/attn/attn_interface.cpp — Threaded k_scale and v_scale through the implementation
  • csrc/flash_attn/flash_api.cpp — Added k_scale and v_scale parameters and FP8 KV dtype validation


if constexpr (Fp8KV) {
  for (int i = 0; i < tArV.size(); ++i) {
    tArV(i) = static_cast<ElementQ>(
        params.scale_k * static_cast<float>(tArV(i)));
  }
}

Copilot AI Jan 26, 2026

The descaling operation uses params.scale_k instead of params.scale_v for the V tensor. This should use params.scale_v since it's processing the value cache, not the key cache.

Suggested change:
-        params.scale_k * static_cast<float>(tArV(i)));
+        params.scale_v * static_cast<float>(tArV(i)));
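To see why the flag matters, here is a toy illustration in plain Python (made-up numbers, not the kernel code): descaling V with K's scale is only harmless when the two scales happen to be equal, as in this PR's tests where both are pinned to 1.0.

```python
# Toy demonstration of the flagged bug: descaling the V cache with
# scale_k gives wrong values whenever scale_k != scale_v.

def descale(stored, scale):
    return [scale * v for v in stored]

v_stored = [4.0, 8.0]          # hypothetical fp8-stored V values
scale_k, scale_v = 0.5, 0.25   # distinct scales expose the bug

buggy   = descale(v_stored, scale_k)   # uses K's scale
correct = descale(v_stored, scale_v)   # uses V's scale
```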

Collaborator:
Seems Copilot is correct. Also, can it be vectorized?

Collaborator (Author):
Fixed, and added unrolling to the PV part.

Comment on lines +66 to +69:

if k_descale is None:
    k_descale = 1.0
if v_descale is None:
    v_descale = 1.0
Collaborator:
Good point, but I think we can start with a static scale for now.
ref. vllm-project/vllm#30141

Collaborator:
Correct. We can start with a static per-tensor scale in the kernel implementation, but at the flash_attn API level we need to unify them into the same shape. Consider the following check:

>>> a = torch.tensor([0.1]).to("xpu")
>>> a = a.expand((10, 10))
>>> a.size()
torch.Size([10, 10])
>>> a.stride()
(0, 0)
>>> a.untyped_storage().nbytes()
4
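The same point can be reproduced with NumPy's broadcast_to, a close analogue of torch.expand (used here only because the stride-0 mechanics are identical): the view advertises a (10, 10) shape while reading every element from a single stored float.

```python
import numpy as np

# NumPy analogue of the torch.expand check above: a one-element scale
# array broadcast to (10, 10) without allocating 100 floats.
scale = np.array([0.1], dtype=np.float32)   # 4 bytes of real storage
a = np.broadcast_to(scale, (10, 10))

# Both strides of the view are 0: every element aliases the same float,
# so a per-tensor scale can be passed around in a per-token shape for free.
```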

baodii (Collaborator) left a comment:
LGTM

jikunshang (Collaborator) left a comment:
Overall LGTM.

"int max_seqlen_q, int max_seqlen_k, float p_dropout, float "
"softmax_scale, Tensor? softmax_sink, bool zero_tensors, "
"int max_seqlen_q, int max_seqlen_k, float p_dropout, float k_scale, "
"float v_scale, "
Collaborator:
can reformat here.

half_t,
half_t,
half_t>::kernel_dispatch(queue, args);
if (cuQKType.q_type == CutlassDType::half) {
Collaborator:
I'd prefer to make the Q/KV types template parameters here too, in the future. I feel the current change will increase compile time a little.


// User-facing arguments
struct Arguments {
ElementS const scale;
Collaborator:
what's scale here for?

Collaborator (Author):
It's the softmax scale.
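For context (a note about the usual convention, not something stated in the PR): the softmax scale is the factor applied to the QK^T logits before softmax, conventionally 1/sqrt(head_dim).

```python
import math

# Conventional attention scaling: logits = (Q @ K^T) * softmax_scale.
head_dim = 64                                # arbitrary example value
softmax_scale = 1.0 / math.sqrt(head_dim)    # 0.125 for head_dim 64
```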

@YizhouZ YizhouZ merged commit a981594 into vllm-project:main Jan 26, 2026
8 checks passed
