diff --git a/csrc/cpu/cpu_attn.cpp b/csrc/cpu/cpu_attn.cpp
index a582b4b4d7cc..18afe4b7925c 100644
--- a/csrc/cpu/cpu_attn.cpp
+++ b/csrc/cpu/cpu_attn.cpp
@@ -1,5 +1,16 @@
 #include "cpu_attn_dispatch_generated.h"
 
+// Maps kv_cache_dtype string to Fp8KVCacheDataType enum.
+// "auto" -> kAuto(0); "fp8"/"fp8_e4m3" -> kFp8E4M3; "fp8_e5m2" -> kFp8E5M2.
+static inline cpu_attention::Fp8KVCacheDataType parse_fp8_kv_dtype(
+    const std::string& kv_cache_dtype) {
+  if (kv_cache_dtype == "fp8_e5m2")
+    return cpu_attention::Fp8KVCacheDataType::kFp8E5M2;
+  if (kv_cache_dtype == "fp8_e4m3" || kv_cache_dtype == "fp8")
+    return cpu_attention::Fp8KVCacheDataType::kFp8E4M3;
+  return cpu_attention::Fp8KVCacheDataType::kAuto;
+}
+
 torch::Tensor get_scheduler_metadata(
     const int64_t num_req, const int64_t num_heads_q,
     const int64_t num_heads_kv, const int64_t head_dim,
@@ -49,7 +60,7 @@ torch::Tensor get_scheduler_metadata(
   input.enable_kv_split = enable_kv_split;
 
   VLLM_DISPATCH_FLOATING_TYPES(dtype, "get_scheduler_metadata", [&]() {
-    CPU_ATTN_DISPATCH(head_dim, isa, [&]() {
+    CPU_ATTN_DISPATCH(head_dim, isa, 0, [&]() {
      input.elem_size = sizeof(scalar_t);
      input.q_buffer_elem_size = sizeof(attn_impl::q_buffer_t);
      input.logits_buffer_elem_size = sizeof(attn_impl::logits_buffer_t);
@@ -72,7 +83,9 @@ void cpu_attn_reshape_and_cache(
         key_cache,  // [num_blocks, num_kv_heads, block_size, head_size]
     torch::Tensor&
         value_cache,  // [num_blocks, num_kv_heads, block_size, head_size]
-    const torch::Tensor& slot_mapping, const std::string& isa) {
+    const torch::Tensor& slot_mapping, const std::string& isa,
+    const double k_scale = 1.0, const double v_scale = 1.0,
+    const std::string& kv_cache_dtype = "auto") {
   TORCH_CHECK_EQ(key.dim(), 3);
   TORCH_CHECK_EQ(value.dim(), 3);
   TORCH_CHECK_EQ(key_cache.dim(), 4);
@@ -80,18 +93,30 @@
   TORCH_CHECK_EQ(key.stride(2), 1);
   TORCH_CHECK_EQ(value.stride(2), 1);
 
+  const int64_t kv_cache_idx =
+      static_cast<int64_t>(parse_fp8_kv_dtype(kv_cache_dtype));
+  const bool is_fp8 = (kv_cache_idx != 0);
+
+  if (is_fp8) {
+    TORCH_CHECK(key_cache.scalar_type() == at::ScalarType::Byte,
+                "key_cache must be uint8 for FP8 path");
+    TORCH_CHECK(value_cache.scalar_type() == at::ScalarType::Byte,
+                "value_cache must be uint8 for FP8 path");
+    TORCH_CHECK(k_scale > 0, "k_scale must be positive for FP8 path");
+    TORCH_CHECK(v_scale > 0, "v_scale must be positive for FP8 path");
+  }
+
+  const float k_inv = is_fp8 ? 1.0f / static_cast<float>(k_scale) : 0.0f;
+  const float v_inv = is_fp8 ? 1.0f / static_cast<float>(v_scale) : 0.0f;
+
   const int64_t token_num = key.size(0);
-  const int64_t key_token_num_stride = key.stride(0);
-  const int64_t value_token_num_stride = value.stride(0);
-  const int64_t head_num = value.size(1);
-  const int64_t key_head_num_stride = key.stride(1);
-  const int64_t value_head_num_stride = value.stride(1);
+  const int64_t head_num = key.size(1);
+  const int64_t head_dim = key.size(2);
   const int64_t num_blocks = key_cache.size(0);
   const int64_t num_blocks_stride = key_cache.stride(0);
   const int64_t cache_head_num_stride = key_cache.stride(1);
   const int64_t block_size = key_cache.size(2);
   const int64_t block_size_stride = key_cache.stride(2);
-  const int64_t head_dim = key.size(-1);
 
   cpu_attention::ISA isa_tag = [&]() {
     if (isa == "amx") {
@@ -109,16 +134,24 @@
     }
   }();
 
+  if (is_fp8) {
+    TORCH_CHECK(isa_tag == cpu_attention::ISA::AMX ||
+                    isa_tag == cpu_attention::ISA::VEC,
+                "FP8 KV cache is only supported on x86 (AMX/VEC) ISA");
+  }
+
   VLLM_DISPATCH_FLOATING_TYPES(
       key.scalar_type(), "cpu_attn_reshape_and_cache", [&]() {
-        CPU_ATTN_DISPATCH(head_dim, isa_tag, [&]() {
+        CPU_ATTN_DISPATCH(head_dim, isa_tag, kv_cache_idx, [&]() {
+          using kv_t = typename attn_impl::kv_cache_t;
           attn_impl::reshape_and_cache(
               key.data_ptr(), value.data_ptr(),
-              key_cache.data_ptr(), value_cache.data_ptr(),
-              slot_mapping.data_ptr(), token_num, key_token_num_stride,
-              value_token_num_stride, head_num, key_head_num_stride,
-              value_head_num_stride, num_blocks, num_blocks_stride,
-              cache_head_num_stride, block_size, block_size_stride);
+              reinterpret_cast<kv_t*>(key_cache.data_ptr()),
+              reinterpret_cast<kv_t*>(value_cache.data_ptr()),
+              slot_mapping.data_ptr(), token_num, key.stride(0),
+              value.stride(0), head_num, key.stride(1), value.stride(1),
+              num_blocks, num_blocks_stride, cache_head_num_stride, block_size,
+              block_size_stride, k_inv, v_inv);
         });
       });
 }
@@ -137,13 +170,26 @@ void cpu_attention_with_kv_cache(
     const int64_t sliding_window_left, const int64_t sliding_window_right,
     const torch::Tensor& block_table,  // [num_tokens, max_block_num]
     const double softcap, const torch::Tensor& scheduler_metadata,
-    const std::optional<torch::Tensor>& s_aux  // [num_heads]
-) {
+    const std::optional<torch::Tensor>& s_aux,  // [num_heads]
+    const double k_scale = 1.0, const double v_scale = 1.0,
+    const std::string& kv_cache_dtype = "auto") {
   TORCH_CHECK_EQ(query.dim(), 3);
   TORCH_CHECK_EQ(query.stride(2), 1);
   TORCH_CHECK_EQ(key_cache.dim(), 4);
   TORCH_CHECK_EQ(value_cache.dim(), 4);
 
+  const int64_t kv_cache_idx =
+      static_cast<int64_t>(parse_fp8_kv_dtype(kv_cache_dtype));
+  const bool is_fp8 = (kv_cache_idx != 0);
+  if (is_fp8) {
+    TORCH_CHECK(key_cache.scalar_type() == at::ScalarType::Byte,
+                "key_cache must be uint8 for FP8 path");
+    TORCH_CHECK(value_cache.scalar_type() == at::ScalarType::Byte,
+                "value_cache must be uint8 for FP8 path");
+    TORCH_CHECK(k_scale > 0, "k_scale must be positive for FP8 path");
+    TORCH_CHECK(v_scale > 0, "v_scale must be positive for FP8 path");
+  }
+
   cpu_attention::AttentionInput input;
   input.metadata = reinterpret_cast(
       scheduler_metadata.data_ptr());
@@ -165,25 +211,32 @@ void cpu_attention_with_kv_cache(
   input.block_table = block_table.data_ptr();
   input.alibi_slopes =
       alibi_slopes.has_value() ? alibi_slopes->data_ptr() : nullptr;
-  // For now sink must be bf16
   input.s_aux = s_aux.has_value() ? s_aux->data_ptr() : nullptr;
   input.scale = scale;
   input.causal = causal;
   input.sliding_window_left = sliding_window_left;
   input.sliding_window_right = sliding_window_right;
   if (input.causal) {
-    // to make boundary calculation easier
     input.sliding_window_right = 0;
   }
-  float softcap_fp32 = softcap;
-  input.softcap = softcap_fp32;
+  input.softcap = static_cast<float>(softcap);
+
+  if (is_fp8) {
+    input.k_scale_fp8 = static_cast<float>(k_scale);
+    input.v_scale_fp8 = static_cast<float>(v_scale);
+    TORCH_CHECK(input.metadata->isa == cpu_attention::ISA::AMX ||
+                    input.metadata->isa == cpu_attention::ISA::VEC,
+                "FP8 KV cache is only supported on x86 (AMX/VEC) ISA");
+  }
 
   VLLM_DISPATCH_FLOATING_TYPES(
       query.scalar_type(), "cpu_attention_with_kv_cache", [&]() {
-        CPU_ATTN_DISPATCH(query.size(2), input.metadata->isa, [&]() {
-          TORCH_CHECK_EQ(input.block_size % attn_impl::BlockSizeAlignment, 0);
-          cpu_attention::AttentionMainLoop mainloop;
-          mainloop(&input);
-        });
+        CPU_ATTN_DISPATCH(
+            query.size(2), input.metadata->isa, kv_cache_idx, [&]() {
+              TORCH_CHECK_EQ(input.block_size % attn_impl::BlockSizeAlignment,
+                             0);
+              cpu_attention::AttentionMainLoop mainloop;
+              mainloop(&input);
+            });
       });
 }
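Note on the scale convention introduced above (not part of the patch): cpu_attn_reshape_and_cache precomputes k_inv = 1/k_scale and v_inv = 1/v_scale, so the cache stores each element multiplied by 1/scale and rounded to FP8, while the read path carries the per-tensor scales into the kernels via input.k_scale_fp8 / input.v_scale_fp8. A minimal stand-alone round trip of one value, assuming only c10's Float8_e4m3fn float conversions:

    #include <c10/util/Float8_e4m3fn.h>

    #include <cmath>
    #include <cstdio>

    int main() {
      const float k_scale = 0.25f;         // hypothetical per-tensor scale
      const float k_inv = 1.0f / k_scale;  // what the cache-write path precomputes
      const float original = 0.0123f;      // an arbitrary key element

      // Cache write: scale down, round to FP8 (E4M3).
      const c10::Float8_e4m3fn stored(original * k_inv);
      // Cache read: convert back to float and fold the scale back in.
      const float recovered = static_cast<float>(stored) * k_scale;

      std::printf("original=%g recovered=%g rel_err=%g\n", original, recovered,
                  std::fabs(recovered - original) / std::fabs(original));
      return 0;
    }
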
diff --git a/csrc/cpu/cpu_attn_amx.hpp b/csrc/cpu/cpu_attn_amx.hpp
index 1c8644d52329..6a0341085dce 100644
--- a/csrc/cpu/cpu_attn_amx.hpp
+++ b/csrc/cpu/cpu_attn_amx.hpp
@@ -1,6 +1,7 @@
 #ifndef CPU_ATTN_AMX_HPP
 #define CPU_ATTN_AMX_HPP
 
+#include "cpu_attn_fp8.hpp"
 #include "cpu_attn_impl.hpp"
 
 namespace cpu_attention {
@@ -21,9 +22,10 @@ typedef struct __tile_config {
 // 2-2-4 pattern, for 16 < m <= 32
 // TILE 0, 1: load A matrix, row num should be 16, m - 16
 // TILE 2, 3: load B matrix, row num should be 16
-// TILE 4, 5, 6, 7: store results C matrix, row num should be 16, 16, m - 16, m
-// - 16
-template <typename scalar_t>
+// TILE 4, 5, 6, 7: store results C matrix, row num should be 16, 16,
+// m - 16, m - 16
+// q_buffer_t: A (Q/P) tile type; kv_cache_t: B (K/V cache) tile type.
+template <typename q_buffer_t, typename kv_cache_t>
 class TileGemm224 {
  public:
   template <AttentionGemmPhase phase>
@@ -42,13 +44,56 @@ class TileGemm224 {
   }
 };
 
-template <>
-class TileGemm224<c10::BFloat16> {
+// Dequantize one FP8 tile (AMX_TILE_ROW_NUM rows x 32 cols) to BF16.
+template <typename kv_cache_t>
+FORCE_INLINE void deq_tile_amx(const uint8_t* src, c10::BFloat16* dst) {
+  for (int r = 0; r < AMX_TILE_ROW_NUM; ++r) {
+    if constexpr (std::is_same_v<kv_cache_t, c10::Float8_e4m3fn>) {
+      vec_op::BF16Vec32(src + r * 32, vec_op::fp8_bf16_e4m3_tag{})
+          .save(dst + r * 32);
+    } else {
+      vec_op::BF16Vec32(src + r * 32, vec_op::fp8_bf16_e5m2_tag{})
+          .save(dst + r * 32);
+    }
+  }
+}
+
+// For FP8: dequant src into scratch and return scratch.
+// For BF16: return src directly (scratch is unused; the compiler elides it).
+template <typename kv_cache_t>
+FORCE_INLINE const c10::BFloat16* prepare_b_tile(const kv_cache_t* src,
+                                                 c10::BFloat16* scratch) {
+  if constexpr (std::is_same_v<kv_cache_t, c10::Float8_e4m3fn> ||
+                std::is_same_v<kv_cache_t, c10::Float8_e5m2>) {
+    deq_tile_amx<kv_cache_t>(reinterpret_cast<const uint8_t*>(src), scratch);
+    return scratch;
+  } else {
+    return reinterpret_cast<const c10::BFloat16*>(src);
+  }
+}
+
+// Handles both BF16 and FP8 KV cache (2-2-4 pattern).
+template <typename kv_cache_t>
+class TileGemm224<c10::BFloat16, kv_cache_t> {
+  static_assert(std::is_same_v<kv_cache_t, c10::BFloat16> ||
+                    std::is_same_v<kv_cache_t, c10::Float8_e4m3fn> ||
+                    std::is_same_v<kv_cache_t, c10::Float8_e5m2>,
+                "kv_cache_t must be BFloat16, Float8_e4m3fn, or Float8_e5m2");
+
+  static constexpr bool fp8_kv =
+      std::is_same_v<kv_cache_t, c10::Float8_e4m3fn> ||
+      std::is_same_v<kv_cache_t, c10::Float8_e5m2>;
+
+  static constexpr int64_t tile_elems = AMX_TILE_BYTES / sizeof(c10::BFloat16);
+  // BF16 path: scratch_elems=1 so the scratch array is eliminated by the
+  // compiler.
+  static constexpr int64_t scratch_elems = fp8_kv ? tile_elems : 1;
+
  public:
   template <AttentionGemmPhase phase>
   FORCE_INLINE static void gemm(const int32_t m_size,
                                 c10::BFloat16* __restrict__ a_tile,
-                                c10::BFloat16* __restrict__ b_tile,
+                                kv_cache_t* __restrict__ b_tile,
                                 float* __restrict__ c_tile, const int64_t lda,
                                 const int64_t ldb, const int64_t ldc,
                                 const int32_t block_size,
@@ -56,6 +101,7 @@ class TileGemm224 {
                                 const bool accum_c) {
     const int32_t k_times =
         dynamic_k_size / (AMX_TILE_ROW_NUM * 4 / sizeof(c10::BFloat16));
+
     c10::BFloat16* __restrict__ a_tile_0 = a_tile;
     c10::BFloat16* __restrict__ a_tile_1 = a_tile + lda * AMX_TILE_ROW_NUM;
     const int64_t a_tile_stride = [&]() {
@@ -70,8 +116,8 @@
       }
     }();
 
-    c10::BFloat16* __restrict__ b_tile_2 = b_tile;
-    c10::BFloat16* __restrict__ b_tile_3 = [&]() {
+    kv_cache_t* __restrict__ b_tile_2 = b_tile;
+    kv_cache_t* __restrict__ b_tile_3 = [&]() {
      if constexpr (phase == AttentionGemmPhase::QK) {
        // k_cache is prepacked
        return b_tile + (k_size * AMX_TILE_ROW_BYTES / 4);
      } else if constexpr (phase == AttentionGemmPhase::PV) {
        // v_cache is prepacked
        return b_tile + (block_size * AMX_TILE_ROW_BYTES / 4);
      } else {
        TORCH_CHECK(false, "Unreachable");
      }
    }();
@@ -106,11 +152,16 @@ {
      _tile_zero(7);
    }
 
+    alignas(64) c10::BFloat16 scratch_2[scratch_elems];
+    alignas(64) c10::BFloat16 scratch_3[scratch_elems];
     for (int32_t k = 0; k < k_times; ++k) {
+      const c10::BFloat16* load_2 = prepare_b_tile(b_tile_2, scratch_2);
+      const c10::BFloat16* load_3 = prepare_b_tile(b_tile_3, scratch_3);
+
       _tile_loadd(0, a_tile_0, a_tile_stride);
-      _tile_stream_loadd(2, b_tile_2, b_tile_stride);
+      _tile_stream_loadd(2, const_cast<c10::BFloat16*>(load_2), b_tile_stride);
       _tile_dpbf16ps(4, 0, 2);
-      _tile_stream_loadd(3, b_tile_3, b_tile_stride);
+      _tile_stream_loadd(3, const_cast<c10::BFloat16*>(load_3), b_tile_stride);
       _tile_dpbf16ps(5, 0, 3);
       _tile_loadd(1, a_tile_1, a_tile_stride);
       _tile_dpbf16ps(6, 1, 2);
@@ -154,13 +205,13 @@
 };
 
 // 1-2-2 pattern, for 0 < m <= 16
-// TILE 0, (1): load A matrix, use extra 1 tile for prefetch, row num should be
-// m, m
-// TILE 2, 3, (4, 5): load B matrix, use extra 2 tiles for prefetch, row
-// num should be 16
-// TILE 6, 7, (6, 7): store results C matrix, row num should be
-// m
-template <typename scalar_t>
+// TILE 0, (1): load A matrix, use extra 1 tile for prefetch, row num should
+// be m, m
+// TILE 2, 3, (4, 5): load B matrix, use extra 2 tiles for prefetch, row num
+// should be 16
+// TILE 6, 7: store results C matrix, row num should be m
+// q_buffer_t: A (Q/P) tile type; kv_cache_t: B (K/V cache) tile type.
+template <typename q_buffer_t, typename kv_cache_t>
 class TileGemm122 {
  public:
   template <AttentionGemmPhase phase>
@@ -179,13 +230,26 @@ class TileGemm122 {
   }
 };
 
-template <>
-class TileGemm122<c10::BFloat16> {
+// Handles both BF16 and FP8 KV cache (1-2-2 pattern).
+template <typename kv_cache_t>
+class TileGemm122<c10::BFloat16, kv_cache_t> {
+  static_assert(std::is_same_v<kv_cache_t, c10::BFloat16> ||
+                    std::is_same_v<kv_cache_t, c10::Float8_e4m3fn> ||
+                    std::is_same_v<kv_cache_t, c10::Float8_e5m2>,
+                "kv_cache_t must be BFloat16, Float8_e4m3fn, or Float8_e5m2");
+
+  static constexpr bool fp8_kv =
+      std::is_same_v<kv_cache_t, c10::Float8_e4m3fn> ||
+      std::is_same_v<kv_cache_t, c10::Float8_e5m2>;
+
+  static constexpr int64_t tile_elems = AMX_TILE_BYTES / sizeof(c10::BFloat16);
+  static constexpr int64_t scratch_elems = fp8_kv ? tile_elems : 1;
+
  public:
   template <AttentionGemmPhase phase>
   FORCE_INLINE static void gemm(const int32_t m_size,
                                 c10::BFloat16* __restrict__ a_tile,
-                                c10::BFloat16* __restrict__ b_tile,
+                                kv_cache_t* __restrict__ b_tile,
                                 float* __restrict__ c_tile, const int64_t lda,
                                 const int64_t ldb, const int64_t ldc,
                                 const int32_t block_size,
@@ -215,21 +279,19 @@
       }
     }();
 
-    c10::BFloat16* __restrict__ b_tile_2 = b_tile;
-    c10::BFloat16* __restrict__ b_tile_3 = [&]() {
+    kv_cache_t* __restrict__ b_tile_2 = b_tile;
+    kv_cache_t* __restrict__ b_tile_3 = [&]() {
       if constexpr (phase == AttentionGemmPhase::QK) {
-        // k_cache is prepacked
         return b_tile + (k_size * AMX_TILE_ROW_BYTES / 4);
       } else if constexpr (phase == AttentionGemmPhase::PV) {
-        // v_cache is prepacked
         return b_tile + (block_size * AMX_TILE_ROW_BYTES / 4);
       } else {
         TORCH_CHECK(false, "Unreachable");
       }
     }();
-    c10::BFloat16* __restrict__ b_tile_4 =
+    kv_cache_t* __restrict__ b_tile_4 =
         b_tile_2 + AMX_TILE_BYTES / sizeof(c10::BFloat16);
-    c10::BFloat16* __restrict__ b_tile_5 =
+    kv_cache_t* __restrict__ b_tile_5 =
         b_tile_3 + AMX_TILE_BYTES / sizeof(c10::BFloat16);
 
     int64_t b_stride = AMX_TILE_ROW_BYTES;
@@ -250,16 +312,25 @@
      _tile_zero(7);
    }
 
+    alignas(64) c10::BFloat16 scratch_2[scratch_elems];
+    alignas(64) c10::BFloat16 scratch_3[scratch_elems];
+    alignas(64) c10::BFloat16 scratch_4[scratch_elems];
+    alignas(64) c10::BFloat16 scratch_5[scratch_elems];
     for (int32_t k = 0; k < k_group_times; ++k) {
+      const c10::BFloat16* load_2 = prepare_b_tile(b_tile_2, scratch_2);
+      const c10::BFloat16* load_3 = prepare_b_tile(b_tile_3, scratch_3);
+      const c10::BFloat16* load_4 = prepare_b_tile(b_tile_4, scratch_4);
+      const c10::BFloat16* load_5 = prepare_b_tile(b_tile_5, scratch_5);
+
       _tile_loadd(0, a_tile_0, a_tile_stride);
-      _tile_stream_loadd(2, b_tile_2, b_stride);
+      _tile_stream_loadd(2, const_cast<c10::BFloat16*>(load_2), b_stride);
       _tile_dpbf16ps(6, 0, 2);
-      _tile_stream_loadd(3, b_tile_3, b_stride);
+      _tile_stream_loadd(3, const_cast<c10::BFloat16*>(load_3), b_stride);
       _tile_dpbf16ps(7, 0, 3);
       _tile_loadd(1, a_tile_1, a_tile_stride);
-      _tile_stream_loadd(4, b_tile_4, b_stride);
+      _tile_stream_loadd(4, const_cast<c10::BFloat16*>(load_4), b_stride);
       _tile_dpbf16ps(6, 1, 4);
-      _tile_stream_loadd(5, b_tile_5, b_stride);
+      _tile_stream_loadd(5, const_cast<c10::BFloat16*>(load_5), b_stride);
       _tile_dpbf16ps(7, 1, 5);
 
       // update ptrs
@@ -279,10 +350,13 @@
     }
 
     if (has_tail) {
+      const c10::BFloat16* load_2 = prepare_b_tile(b_tile_2, scratch_2);
+      const c10::BFloat16* load_3 = prepare_b_tile(b_tile_3, scratch_3);
+
       _tile_loadd(0, a_tile_0, a_tile_stride);
-      _tile_stream_loadd(2, b_tile_2, b_stride);
+      _tile_stream_loadd(2, const_cast<c10::BFloat16*>(load_2), b_stride);
       _tile_dpbf16ps(6, 0, 2);
-      _tile_stream_loadd(3, b_tile_3, b_stride);
+      _tile_stream_loadd(3, const_cast<c10::BFloat16*>(load_3), b_stride);
       _tile_dpbf16ps(7, 0, 3);
     }
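The B-tile handling above follows one pattern in both TileGemm224 and TileGemm122: an FP8 tile is expanded into a 64-byte-aligned BF16 scratch tile immediately before _tile_stream_loadd, while a BF16 tile is passed through and the one-element scratch array is eliminated by the compiler. A scalar sketch of that idea (illustrative only, not the patch's code; the real kernel uses the vectorized BF16Vec32 conversion and defers a constant 2^(127-bias) factor to the output scale, see get_output_v_scale below, whereas this sketch converts exactly):

    #include <c10/util/BFloat16.h>
    #include <c10/util/Float8_e4m3fn.h>

    #include <type_traits>

    // One AMX B tile: 16 rows x 32 BF16 columns (1 KiB of BF16 data).
    constexpr int kTileRows = 16;
    constexpr int kTileCols = 32;

    template <typename kv_cache_t>
    const c10::BFloat16* prepare_b_tile_scalar(const kv_cache_t* src,
                                               c10::BFloat16* scratch) {
      if constexpr (std::is_same_v<kv_cache_t, c10::BFloat16>) {
        return src;  // BF16 cache: the tile is loaded directly.
      } else {
        // FP8 cache: widen every element into the scratch tile first.
        for (int i = 0; i < kTileRows * kTileCols; ++i) {
          scratch[i] = c10::BFloat16(static_cast<float>(src[i]));
        }
        return scratch;
      }
    }
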
@@ -302,21 +376,25 @@
     _tile_loadconfig(&config);
   }
 };
+
 } // namespace
 
-template <typename scalar_t>
-class AttentionImpl {
+template <typename scalar_t, typename kv_cache_scalar_t>
+class AttentionImpl {
+  static constexpr bool fp8_kv =
+      std::is_same_v<kv_cache_scalar_t, c10::Float8_e4m3fn> ||
+      std::is_same_v<kv_cache_scalar_t, c10::Float8_e5m2>;
+
  public:
   using query_t = scalar_t;
   using q_buffer_t = scalar_t;
-  using kv_cache_t = scalar_t;
+  using kv_cache_t = kv_cache_scalar_t;
   using logits_buffer_t = float;
   using partial_output_buffer_t = float;
   using prob_buffer_t = scalar_t;
 
   constexpr static int64_t BlockSizeAlignment =
-      AMX_TILE_ROW_BYTES /
-      sizeof(kv_cache_t);  // KV token num unit of QK and PV phases
+      32;  // AMX_TILE_ROW_NUM = 16 tokens/tile; 32 = 2 tiles
   constexpr static int64_t HeadDimAlignment =
       2 * (AMX_TILE_ROW_BYTES / 4);  // headdim num unit of PV phase
   constexpr static int64_t MaxQHeadNumPerIteration = 32;
@@ -324,6 +402,9 @@ class AttentionImpl {
   constexpr static ISA ISAType = ISA::AMX;
   constexpr static bool scale_on_logits = true;
 
+  float k_scale = 1.0f;
+  float v_scale = 1.0f;
+
  public:
   AttentionImpl() : current_q_head_num_(0) {
     // Use all columns in AMX tiles
@@ -332,21 +413,50 @@
   ~AttentionImpl() { _tile_release(); }
 
+  void init_from_input(const AttentionInput* input) {
+    if constexpr (fp8_kv) {
+      k_scale = input->k_scale_fp8;
+      v_scale = input->v_scale_fp8;
+    }
+  }
+
+  float get_output_v_scale() const noexcept {
+    if constexpr (fp8_kv) {
+      // AMX dequant places FP8 payload into a BF16 field (exponent bias 127).
+      // Correction = 2^(127 - FP8_bias): E4M3 bias=7 → 2^120, E5M2 bias=15 →
+      // 2^112.
+      constexpr float bias =
+          std::is_same_v<kv_cache_scalar_t, c10::Float8_e5m2> ? 0x1p112f : 0x1p120f;
+      return v_scale * bias;
+    }
+    return 1.0f;
+  }
+
+  template
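On the get_output_v_scale correction above (a side note, not part of the patch): dropping an FP8 exponent field unchanged into a BF16 exponent field changes the implicit bias from 7 (E4M3) or 15 (E5M2) to 127, so the decoded value is low by exactly 2^(127 - bias), i.e. 2^120 or 2^112, which the patch folds into the output scale. A numeric check for one normal E4M3 value (illustrative bit layout only; zero, NaN, and denormals are ignored):

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    int main() {
      const uint8_t fp8 = 0x40;  // E4M3: sign 0, exponent 1000b (=8), mantissa 000b -> 2.0
      const uint16_t sign = static_cast<uint16_t>((fp8 & 0x80) << 8);
      const uint16_t exp = static_cast<uint16_t>((fp8 & 0x78) >> 3);  // 4-bit exponent, bias 7
      const uint16_t mant = static_cast<uint16_t>(fp8 & 0x07);

      // Re-pack the fields into BF16 (8-bit exponent, bias 127) without re-biasing.
      const uint16_t bf16_bits = static_cast<uint16_t>(sign | (exp << 7) | (mant << 4));
      const uint32_t f32_bits = static_cast<uint32_t>(bf16_bits) << 16;
      float decoded;
      std::memcpy(&decoded, &f32_bits, sizeof(decoded));

      // decoded == 2.0 * 2^-120; multiplying by 0x1p120f recovers 2.0.
      std::printf("raw decode = %g, corrected = %g\n", decoded, decoded * 0x1p120f);
      return 0;
    }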