Merged
Changes from all commits
Commits
44 commits
522c64d
Initial fallback fp8 kv cache implementation
tianmu-li Apr 1, 2026
0aaf9a1
WIP
tianmu-li Apr 2, 2026
30c30ac
fp8->fp16->fp32 2-stage implementation
tianmu-li Apr 6, 2026
c8727d2
Cleanup
tianmu-li Apr 6, 2026
e329db1
Add AMX and fp8_e5m2
tianmu-li Apr 7, 2026
b80aa3a
WIP AMX improvement
tianmu-li Apr 8, 2026
e57eed4
Move scale multiplication for k and v to matmul results; remove dead …
tianmu-li Apr 8, 2026
9947fc7
Move scale multiplication for AMX to matmul results
tianmu-li Apr 8, 2026
1a07052
Direct fp8->bf16 conversion for amx
tianmu-li Apr 8, 2026
ea5cabe
Cleanup; always use AMX for fp8 attn when available
tianmu-li Apr 9, 2026
7e3af61
Comment cleanup
tianmu-li Apr 9, 2026
d1fac47
Remove benchmark file
tianmu-li Apr 9, 2026
f8578d1
Merge remote-tracking branch 'origin/main' into tianmu/cpu_fp8_attn
tianmu-li Apr 9, 2026
b47a6b2
Simplify fp8 function calls
tianmu-li Apr 9, 2026
1088172
Use template calls for fp8 cpu variants; address review comments
tianmu-li Apr 21, 2026
9232093
Merge remote-tracking branch 'origin/main' into tianmu/cpu_fp8_attn
tianmu-li Apr 21, 2026
0e1ce52
Simplify TileGemm structure for fp8
tianmu-li Apr 21, 2026
ea73ceb
Use kv_cache_scalar_t for non-supported fp8 platforms; cleanup
tianmu-li Apr 22, 2026
f9abd40
Unify FP8/non-FP8 dispatch; fix FP8 VEC on non-AMX x86
tianmu-li Apr 22, 2026
e6d3977
Pass k_scale, v_scale, kv_cache_dtype unconditionally to C++ backend
tianmu-li Apr 22, 2026
6f417e5
Merge remote-tracking branch 'origin/main' into tianmu/cpu_fp8_attn
tianmu-li Apr 22, 2026
9a78530
WIP
tianmu-li Apr 24, 2026
5f57d86
WIP
tianmu-li Apr 24, 2026
5223ce4
WIP cleanup
tianmu-li Apr 24, 2026
7afde8c
cpu fp8 attn: guard pack_kv_cache, dedup quant fn dispatch, add featu…
tianmu-li Apr 24, 2026
a1196e6
cpu fp8 attn: revert kv_cache_scalar_t rename in NEON/VXE/VEC16 helpe…
tianmu-li Apr 24, 2026
0fe3c63
Comment update
tianmu-li Apr 24, 2026
7a5e97e
Minor syntax
tianmu-li Apr 25, 2026
9f50e46
Fix kv_cache_scalar_t/kv_cache_t namings
tianmu-li Apr 25, 2026
1019419
Disable fp8 attn for avx2
tianmu-li Apr 25, 2026
0fe182e
Merge branch 'main' into tianmu/cpu_fp8_attn
tianmu-li Apr 25, 2026
05bb143
Truly disable fp8 attn for avx-2
tianmu-li Apr 25, 2026
2a9314b
Merge tests for fp8 into existing ones; restore unintended changes
tianmu-li Apr 28, 2026
5a14af8
Merge remote-tracking branch 'origin/main' into tianmu/cpu_fp8_attn
tianmu-li Apr 28, 2026
d67d415
Fix pre-commit error
tianmu-li Apr 28, 2026
25c9d46
Removed unintended changes
tianmu-li Apr 28, 2026
b53bb95
Remove unintended changes
tianmu-li Apr 28, 2026
d44d5cc
Gate fp8 kv cache to x86 only
tianmu-li Apr 28, 2026
1cf96c8
Typo
tianmu-li Apr 28, 2026
a0ff62e
Guard get_output_v_scale behind fp8_kv
tianmu-li Apr 29, 2026
a79e132
Add missing k_inv and v_inv for neon BFMMLA
tianmu-li Apr 29, 2026
38a7aec
Add fp8_e4m3 and fp8_e5m2 tags for arm and vxe
tianmu-li Apr 29, 2026
ec2a69f
Merge branch 'main' into tianmu/cpu_fp8_attn
bigPYJ1151 Apr 29, 2026
d33be6f
Merge branch 'main' into tianmu/cpu_fp8_attn
bigPYJ1151 Apr 29, 2026
103 changes: 78 additions & 25 deletions csrc/cpu/cpu_attn.cpp
@@ -1,5 +1,16 @@
#include "cpu_attn_dispatch_generated.h"

// Maps kv_cache_dtype string to Fp8KVCacheDataType enum.
// "auto" -> kAuto(0); "fp8"/"fp8_e4m3" -> kFp8E4M3; "fp8_e5m2" -> kFp8E5M2.
static inline cpu_attention::Fp8KVCacheDataType parse_fp8_kv_dtype(
const std::string& kv_cache_dtype) {
if (kv_cache_dtype == "fp8_e5m2")
return cpu_attention::Fp8KVCacheDataType::kFp8E5M2;
if (kv_cache_dtype == "fp8_e4m3" || kv_cache_dtype == "fp8")
return cpu_attention::Fp8KVCacheDataType::kFp8E4M3;
return cpu_attention::Fp8KVCacheDataType::kAuto;
}

torch::Tensor get_scheduler_metadata(
const int64_t num_req, const int64_t num_heads_q,
const int64_t num_heads_kv, const int64_t head_dim,
@@ -49,7 +60,7 @@ torch::Tensor get_scheduler_metadata(
input.enable_kv_split = enable_kv_split;

VLLM_DISPATCH_FLOATING_TYPES(dtype, "get_scheduler_metadata", [&]() {
CPU_ATTN_DISPATCH(head_dim, isa, [&]() {
CPU_ATTN_DISPATCH(head_dim, isa, 0, [&]() {
input.elem_size = sizeof(scalar_t);
input.q_buffer_elem_size = sizeof(attn_impl::q_buffer_t);
input.logits_buffer_elem_size = sizeof(attn_impl::logits_buffer_t);
@@ -72,26 +83,40 @@ void cpu_attn_reshape_and_cache(
key_cache, // [num_blocks, num_kv_heads, block_size, head_size]
torch::Tensor&
value_cache, // [num_blocks, num_kv_heads, block_size, head_size]
const torch::Tensor& slot_mapping, const std::string& isa) {
const torch::Tensor& slot_mapping, const std::string& isa,
const double k_scale = 1.0, const double v_scale = 1.0,
const std::string& kv_cache_dtype = "auto") {
TORCH_CHECK_EQ(key.dim(), 3);
TORCH_CHECK_EQ(value.dim(), 3);
TORCH_CHECK_EQ(key_cache.dim(), 4);
TORCH_CHECK_EQ(value_cache.dim(), 4);
TORCH_CHECK_EQ(key.stride(2), 1);
TORCH_CHECK_EQ(value.stride(2), 1);

const int64_t kv_cache_idx =
static_cast<int64_t>(parse_fp8_kv_dtype(kv_cache_dtype));
const bool is_fp8 = (kv_cache_idx != 0);

if (is_fp8) {
TORCH_CHECK(key_cache.scalar_type() == at::ScalarType::Byte,
"key_cache must be uint8 for FP8 path");
TORCH_CHECK(value_cache.scalar_type() == at::ScalarType::Byte,
"value_cache must be uint8 for FP8 path");
TORCH_CHECK(k_scale > 0, "k_scale must be positive for FP8 path");
TORCH_CHECK(v_scale > 0, "v_scale must be positive for FP8 path");
}

const float k_inv = is_fp8 ? 1.0f / static_cast<float>(k_scale) : 0.0f;
const float v_inv = is_fp8 ? 1.0f / static_cast<float>(v_scale) : 0.0f;

const int64_t token_num = key.size(0);
const int64_t key_token_num_stride = key.stride(0);
const int64_t value_token_num_stride = value.stride(0);
const int64_t head_num = value.size(1);
const int64_t key_head_num_stride = key.stride(1);
const int64_t value_head_num_stride = value.stride(1);
const int64_t head_num = key.size(1);
const int64_t head_dim = key.size(2);
const int64_t num_blocks = key_cache.size(0);
const int64_t num_blocks_stride = key_cache.stride(0);
const int64_t cache_head_num_stride = key_cache.stride(1);
const int64_t block_size = key_cache.size(2);
const int64_t block_size_stride = key_cache.stride(2);
const int64_t head_dim = key.size(-1);

cpu_attention::ISA isa_tag = [&]() {
if (isa == "amx") {
@@ -109,16 +134,24 @@ }
}
}();

if (is_fp8) {
TORCH_CHECK(isa_tag == cpu_attention::ISA::AMX ||
isa_tag == cpu_attention::ISA::VEC,
"FP8 KV cache is only supported on x86 (AMX/VEC) ISA");
}

VLLM_DISPATCH_FLOATING_TYPES(
key.scalar_type(), "cpu_attn_reshape_and_cache", [&]() {
CPU_ATTN_DISPATCH(head_dim, isa_tag, [&]() {
CPU_ATTN_DISPATCH(head_dim, isa_tag, kv_cache_idx, [&]() {
using kv_t = typename attn_impl::kv_cache_t;
attn_impl::reshape_and_cache(
key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(), value_cache.data_ptr<scalar_t>(),
slot_mapping.data_ptr<int64_t>(), token_num, key_token_num_stride,
value_token_num_stride, head_num, key_head_num_stride,
value_head_num_stride, num_blocks, num_blocks_stride,
cache_head_num_stride, block_size, block_size_stride);
reinterpret_cast<kv_t*>(key_cache.data_ptr()),
reinterpret_cast<kv_t*>(value_cache.data_ptr()),
slot_mapping.data_ptr<int64_t>(), token_num, key.stride(0),
value.stride(0), head_num, key.stride(1), value.stride(1),
num_blocks, num_blocks_stride, cache_head_num_stride, block_size,
block_size_stride, k_inv, v_inv);
});
});
}
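
As context for the k_inv / v_inv values threaded into reshape_and_cache above, a minimal standalone sketch of the scale convention (names and values are illustrative; the actual FP8 narrowing happens inside the ISA-specific kernels):

#include <cstdio>

int main() {
  const double k_scale = 0.0625;  // per-tensor scale, e.g. loaded from the checkpoint
  const float k_inv = 1.0f / static_cast<float>(k_scale);  // mirrors the k_inv computed above

  const float key_elem = 0.0123f;                                // bf16/fp32 key element
  const float cached = key_elem * k_inv;                         // scaled value handed to the FP8 cast in the kernel
  const float restored = cached * static_cast<float>(k_scale);   // scale folded back into the matmul result later

  std::printf("cached=%.6f restored=%.6f\n", cached, restored);
  return 0;
}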
@@ -137,13 +170,26 @@ void cpu_attention_with_kv_cache(
const int64_t sliding_window_left, const int64_t sliding_window_right,
const torch::Tensor& block_table, // [num_tokens, max_block_num]
const double softcap, const torch::Tensor& scheduler_metadata,
const std::optional<torch::Tensor>& s_aux // [num_heads]
) {
const std::optional<torch::Tensor>& s_aux, // [num_heads]
const double k_scale = 1.0, const double v_scale = 1.0,
const std::string& kv_cache_dtype = "auto") {
TORCH_CHECK_EQ(query.dim(), 3);
TORCH_CHECK_EQ(query.stride(2), 1);
TORCH_CHECK_EQ(key_cache.dim(), 4);
TORCH_CHECK_EQ(value_cache.dim(), 4);

const int64_t kv_cache_idx =
static_cast<int64_t>(parse_fp8_kv_dtype(kv_cache_dtype));
const bool is_fp8 = (kv_cache_idx != 0);
if (is_fp8) {
TORCH_CHECK(key_cache.scalar_type() == at::ScalarType::Byte,
"key_cache must be uint8 for FP8 path");
TORCH_CHECK(value_cache.scalar_type() == at::ScalarType::Byte,
"value_cache must be uint8 for FP8 path");
TORCH_CHECK(k_scale > 0, "k_scale must be positive for FP8 path");
TORCH_CHECK(v_scale > 0, "v_scale must be positive for FP8 path");
}

cpu_attention::AttentionInput input;
input.metadata = reinterpret_cast<cpu_attention::AttentionMetadata*>(
scheduler_metadata.data_ptr());
@@ -165,25 +211,32 @@
input.block_table = block_table.data_ptr<int32_t>();
input.alibi_slopes =
alibi_slopes.has_value() ? alibi_slopes->data_ptr<float>() : nullptr;
// For now sink must be bf16
input.s_aux = s_aux.has_value() ? s_aux->data_ptr<c10::BFloat16>() : nullptr;
input.scale = scale;
input.causal = causal;
input.sliding_window_left = sliding_window_left;
input.sliding_window_right = sliding_window_right;
if (input.causal) {
// to make boundary calculation easier
input.sliding_window_right = 0;
}
float softcap_fp32 = softcap;
input.softcap = softcap_fp32;
input.softcap = static_cast<float>(softcap);

if (is_fp8) {
input.k_scale_fp8 = static_cast<float>(k_scale);
input.v_scale_fp8 = static_cast<float>(v_scale);
TORCH_CHECK(input.metadata->isa == cpu_attention::ISA::AMX ||
input.metadata->isa == cpu_attention::ISA::VEC,
"FP8 KV cache is only supported on x86 (AMX/VEC) ISA");
}

VLLM_DISPATCH_FLOATING_TYPES(
query.scalar_type(), "cpu_attention_with_kv_cache", [&]() {
CPU_ATTN_DISPATCH(query.size(2), input.metadata->isa, [&]() {
TORCH_CHECK_EQ(input.block_size % attn_impl::BlockSizeAlignment, 0);
cpu_attention::AttentionMainLoop<attn_impl> mainloop;
mainloop(&input);
});
CPU_ATTN_DISPATCH(
query.size(2), input.metadata->isa, kv_cache_idx, [&]() {
TORCH_CHECK_EQ(input.block_size % attn_impl::BlockSizeAlignment,
0);
cpu_attention::AttentionMainLoop<attn_impl> mainloop;
mainloop(&input);
});
});
}
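
The new third argument to CPU_ATTN_DISPATCH is the integer value of the parsed Fp8KVCacheDataType. A rough sketch of the kind of selection such a dispatch performs, with hypothetical names (the real macro is generated into cpu_attn_dispatch_generated.h and also dispatches on head_dim and ISA; the non-zero enum values here are assumed):

#include <cstdint>

enum class Fp8KVCacheDataType : int64_t { kAuto = 0, kFp8E4M3 = 1, kFp8E5M2 = 2 };

// Hypothetical stand-in for the generated attn_impl types.
template <typename KvT>
struct AttnImplSketch {
  using kv_cache_t = KvT;  // element type the kernels read/write in the cache tensors
};

template <typename Fn>
void dispatch_kv_cache(int64_t kv_cache_idx, Fn&& fn) {
  switch (static_cast<Fp8KVCacheDataType>(kv_cache_idx)) {
    case Fp8KVCacheDataType::kFp8E4M3:
    case Fp8KVCacheDataType::kFp8E5M2:
      // FP8 caches arrive as uint8 tensors; the kernel reinterprets the bytes.
      fn(AttnImplSketch<uint8_t>{});
      break;
    default:
      // "auto": the cache element type matches the activation dtype.
      fn(AttnImplSketch<float>{});
      break;
  }
}

int main() {
  dispatch_kv_cache(1, [](auto impl) {
    using kv_t = typename decltype(impl)::kv_cache_t;  // uint8_t on the FP8 path
    (void)sizeof(kv_t);
  });
  return 0;
}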