diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index cea38e1ee..cef5eb240 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -22,6 +22,8 @@ std::vector mha_varlen_fwd( int max_seqlen_q, int max_seqlen_k, float p_dropout, + float k_scale, + float v_scale, float softmax_scale, std::optional& softmax_sink_, const bool zero_tensors, @@ -32,14 +34,23 @@ std::vector mha_varlen_fwd( const bool return_softmax, std::optional gen_) { auto q_type = q.scalar_type(); + auto k_type = k.scalar_type(); TORCH_CHECK( q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, "VLLM Kernel XPU only supports fp16 and bf16 type"); TORCH_CHECK( - k.scalar_type() == q_type, "query and key must have the same dtype"); - TORCH_CHECK( - v.scalar_type() == q_type, "query and value must have the same dtype"); + v.scalar_type() == k_type, "key and value must have the same dtype"); + bool is_fp8kv = false; + if (k_type == at::ScalarType::Float8_e5m2 || + k_type == at::ScalarType::Float8_e4m3fn) { + is_fp8kv = true; + } else { + TORCH_CHECK( + k.scalar_type() == q_type, "query and key must have the same dtype"); + TORCH_CHECK( + v.scalar_type() == q_type, "query and value must have the same dtype"); + } CHECK_DEVICE(q); CHECK_DEVICE(k); @@ -94,7 +105,7 @@ std::vector mha_varlen_fwd( bool is_local = (window_size_left != -1) | (window_size_right != -1); bool is_sink = softmax_sink_.has_value(); - if (max_seqlen_q > 1 || is_local || !is_paged) { + if (max_seqlen_q > 1 || is_local || !is_paged || is_fp8kv) { at::Tensor seqlens_k = is_paged ? *seqused_k : cu_seqlens_k; cutlass_chunk_prefill_interface( @@ -108,6 +119,8 @@ std::vector mha_varlen_fwd( seqlens_k, max_seqlen_q, max_seqlen_k, + k_scale, + v_scale, softmax_scale, softmax_sink_, window_size_left, @@ -182,8 +195,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "cu_seqlens_q, " "Tensor cu_seqlens_k, Tensor? seqused_k, Tensor? leftpad_k, Tensor? " "block_table, Tensor? alibi_slopes, " - "int max_seqlen_q, int max_seqlen_k, float p_dropout, float " - "softmax_scale, Tensor? softmax_sink, bool zero_tensors, " + "int max_seqlen_q, int max_seqlen_k, float p_dropout, float k_scale, " + "float v_scale, " + "float softmax_scale, Tensor? softmax_sink, bool zero_tensors, " "bool is_causal, int window_size_left, int window_size_right, float " "softcap, bool return_softmax, " "Generator? gen) -> Tensor[]"); diff --git a/csrc/xpu/attn/attn_interface.cpp b/csrc/xpu/attn/attn_interface.cpp index 59375fde5..2c087be49 100644 --- a/csrc/xpu/attn/attn_interface.cpp +++ b/csrc/xpu/attn/attn_interface.cpp @@ -17,6 +17,8 @@ void cutlass_chunk_prefill_interface( const at::Tensor& cu_seqlens_k, int max_seqlen_q, int max_seqlen_k, + float k_scale, + float v_scale, double sm_scale, std::optional& sm_sink_, int window_size_left, @@ -40,6 +42,8 @@ void cutlass_chunk_prefill_interface( cu_seqlens_k, max_seqlen_q, max_seqlen_k, + k_scale, + v_scale, sm_scale, sm_sink_, window_size_left, diff --git a/csrc/xpu/attn/attn_interface.h b/csrc/xpu/attn/attn_interface.h index b35cf6faa..91b96bc05 100644 --- a/csrc/xpu/attn/attn_interface.h +++ b/csrc/xpu/attn/attn_interface.h @@ -10,6 +10,8 @@ void cutlass_chunk_prefill_interface( const at::Tensor& cu_seqlens_k, int max_seqlen_q, int max_seqlen_k, + float k_scale, + float v_scale, double sm_scale, std::optional& sm_sink_, int window_size_left, diff --git a/csrc/xpu/attn/xe_2/chunk_prefill.hpp b/csrc/xpu/attn/xe_2/chunk_prefill.hpp index bb9673237..316be6e93 100644 --- a/csrc/xpu/attn/xe_2/chunk_prefill.hpp +++ b/csrc/xpu/attn/xe_2/chunk_prefill.hpp @@ -36,6 +36,8 @@ struct chunk_prefill_args_t { int max_keys; int total_seqlen_q; int total_seqlen_k; + float k_scale = 1.0; + float v_scale = 1.0; float sm_scale; void* sm_sink; int batch_size; @@ -145,6 +147,8 @@ struct KernelLauncher { stride_O, reinterpret_cast(args.sm_sink)}, {args.sm_scale, + args.k_scale, + args.v_scale, static_cast(args.block_table), args.block_size, args.max_blocks_per_seq, @@ -315,35 +319,109 @@ struct FMHAConfig { template void policy_dispatch_impl( - sycl::queue& queue, CutlassType cuType, const chunk_prefill_args_t& args) { + sycl::queue& queue, + CutlassQKType& cuQKType, + const chunk_prefill_args_t& args) { const int PipelineStages = 2; - if (cuType == CutlassType::half) { - return FMHAConfig< - typename chunk_policy::ShapeQK, - typename chunk_policy::ShapePV, - typename chunk_policy::ShapeOut, - typename chunk_policy::SubgroupLayoutQK, - void, - PipelineStages, - Paged, - Causal, - Local, - Sink, - half_t, - half_t, - half_t, - half_t>::kernel_dispatch(queue, args); + if (cuQKType.q_type == CutlassDType::half) { + if (cuQKType.k_type == CutlassDType::half) { + return FMHAConfig< + typename chunk_policy::ShapeQK, + typename chunk_policy::ShapePV, + typename chunk_policy::ShapeOut, + typename chunk_policy::SubgroupLayoutQK, + void, + PipelineStages, + Paged, + Causal, + Local, + Sink, + half_t, + half_t, + half_t, + half_t>::kernel_dispatch(queue, args); + } else if (cuQKType.k_type == CutlassDType::float8_e4m3) { + return FMHAConfig< + typename chunk_policy::ShapeQK, + typename chunk_policy::ShapePV, + typename chunk_policy::ShapeOut, + typename chunk_policy::SubgroupLayoutQK, + void, + PipelineStages, + Paged, + Causal, + Local, + Sink, + half_t, + float_e4m3_t, + float_e4m3_t, + half_t>::kernel_dispatch(queue, args); + } else if (cuQKType.k_type == CutlassDType::float8_e5m2) { + return FMHAConfig< + typename chunk_policy::ShapeQK, + typename chunk_policy::ShapePV, + typename chunk_policy::ShapeOut, + typename chunk_policy::SubgroupLayoutQK, + void, + PipelineStages, + Paged, + Causal, + Local, + Sink, + half_t, + float_e5m2_t, + float_e5m2_t, + half_t>::kernel_dispatch(queue, args); + } } else { - return FMHAConfig< - typename chunk_policy::ShapeQK, - typename chunk_policy::ShapePV, - typename chunk_policy::ShapeOut, - typename chunk_policy::SubgroupLayoutQK, - void, - PipelineStages, - Paged, - Causal, - Local, - Sink>::kernel_dispatch(queue, args); + if (cuQKType.k_type == CutlassDType::bfloat16) { + return FMHAConfig< + typename chunk_policy::ShapeQK, + typename chunk_policy::ShapePV, + typename chunk_policy::ShapeOut, + typename chunk_policy::SubgroupLayoutQK, + void, + PipelineStages, + Paged, + Causal, + Local, + Sink, + bfloat16_t, + bfloat16_t, + bfloat16_t, + bfloat16_t>::kernel_dispatch(queue, args); + } else if (cuQKType.k_type == CutlassDType::float8_e4m3) { + return FMHAConfig< + typename chunk_policy::ShapeQK, + typename chunk_policy::ShapePV, + typename chunk_policy::ShapeOut, + typename chunk_policy::SubgroupLayoutQK, + void, + PipelineStages, + Paged, + Causal, + Local, + Sink, + bfloat16_t, + float_e4m3_t, + float_e4m3_t, + bfloat16_t>::kernel_dispatch(queue, args); + } else if (cuQKType.k_type == CutlassDType::float8_e5m2) { + return FMHAConfig< + typename chunk_policy::ShapeQK, + typename chunk_policy::ShapePV, + typename chunk_policy::ShapeOut, + typename chunk_policy::SubgroupLayoutQK, + void, + PipelineStages, + Paged, + Causal, + Local, + Sink, + bfloat16_t, + float_e5m2_t, + float_e5m2_t, + bfloat16_t>::kernel_dispatch(queue, args); + } } } diff --git a/csrc/xpu/attn/xe_2/chunk_prefill_extern.hpp b/csrc/xpu/attn/xe_2/chunk_prefill_extern.hpp index 170f5377d..619be7dc7 100644 --- a/csrc/xpu/attn/xe_2/chunk_prefill_extern.hpp +++ b/csrc/xpu/attn/xe_2/chunk_prefill_extern.hpp @@ -26,7 +26,7 @@ extern template void \ policy_dispatch_impl( \ sycl::queue & queue, \ - CutlassType cuType, \ + CutlassQKType & cuQKType, \ const chunk_prefill_args_t& args); // Generate all 16 bool combinations for a given policy using nested macros diff --git a/csrc/xpu/attn/xe_2/chunk_prefill_kernel_template.cpp.in b/csrc/xpu/attn/xe_2/chunk_prefill_kernel_template.cpp.in index b09bfe217..a9a575c37 100644 --- a/csrc/xpu/attn/xe_2/chunk_prefill_kernel_template.cpp.in +++ b/csrc/xpu/attn/xe_2/chunk_prefill_kernel_template.cpp.in @@ -21,7 +21,7 @@ using namespace cute; static_cast(IMPL_KISLOCAL), \ static_cast(IMPL_KISSINK)>( \ sycl::queue & queue, \ - CutlassType cuType, \ + CutlassQKType& cuQKType, \ const chunk_prefill_args_t& args); INSTANTIATE_KERNEL() diff --git a/csrc/xpu/attn/xe_2/chunk_prefill_utils.hpp b/csrc/xpu/attn/xe_2/chunk_prefill_utils.hpp index 74dcd51d1..1c0a93030 100644 --- a/csrc/xpu/attn/xe_2/chunk_prefill_utils.hpp +++ b/csrc/xpu/attn/xe_2/chunk_prefill_utils.hpp @@ -4,22 +4,25 @@ using namespace cute; template void policy_dispatch_func( - sycl::queue& queue, CutlassType cuType, const chunk_prefill_args_t& args) { - policy_dispatch_impl(queue, cuType, args); + sycl::queue& queue, + CutlassQKType& cuQKType, + const chunk_prefill_args_t& args) { + policy_dispatch_impl(queue, cuQKType, args); } template void policy_dispatch_func( sycl::queue& queue, - CutlassType cuType, + CutlassQKType& cuQKType, const chunk_prefill_args_t& args, bool b, Ts... ts) { if (b) { - policy_dispatch_func(queue, cuType, args, ts...); + policy_dispatch_func( + queue, cuQKType, args, ts...); } else { policy_dispatch_func( - queue, cuType, args, ts...); + queue, cuQKType, args, ts...); } } @@ -34,6 +37,8 @@ void cutlass_chunk_prefill_impl( const at::Tensor& cu_seqlens_k, int max_seqlen_q, int max_seqlen_k, + float k_scale, + float v_scale, double sm_scale, std::optional& sm_sink_, int window_size_left, diff --git a/csrc/xpu/attn/xe_2/collective/chunk_prefill_mainloop.hpp b/csrc/xpu/attn/xe_2/collective/chunk_prefill_mainloop.hpp index 4939c631c..d7d157876 100644 --- a/csrc/xpu/attn/xe_2/collective/chunk_prefill_mainloop.hpp +++ b/csrc/xpu/attn/xe_2/collective/chunk_prefill_mainloop.hpp @@ -131,6 +131,9 @@ struct FMHAFwdMainloop< using TensorK = TensorK_; using TensorV = TensorV_; + using ElementQ = typename TensorQ::engine_type::value_type; + using ElementK = typename TensorK::engine_type::value_type; + using TensorQ2D = decltype(TensorQ_{}(append>(make_coord(_, _), 0))); using TensorK2D = @@ -178,6 +181,8 @@ struct FMHAFwdMainloop< using FragARow = decltype(reduce<1>(FragA{}, sycl::plus{})); using ElementA = typename TiledMMAPV::ValTypeD; + static constexpr bool Fp8KV = + is_any_of_v; static constexpr bool CausalMask = CausalMask_; static constexpr bool LocalMask = LocalMask_; static constexpr bool PagedKV = PagedKV_; @@ -185,6 +190,8 @@ struct FMHAFwdMainloop< // User-facing arguments struct Arguments { ElementS const scale; + ElementS const scale_k; + ElementS const scale_v; // Paged KV Cache int* ptr_page_table; @@ -215,6 +222,8 @@ struct FMHAFwdMainloop< ElementS val = args.scale * static_cast(kLog2e); return Params{ val, + args.scale_k, + args.scale_v, args.ptr_page_table, args.page_size, args.max_pages_per_seq, @@ -375,6 +384,12 @@ struct FMHAFwdMainloop< reorder(tQrQ, tSrQ); reorder(tKrK, tSrK); + if constexpr (Fp8KV) { + for (int i = 0; i < tSrK.size(); ++i) { + tSrK(i) = static_cast( + params.scale_k * static_cast(tSrK(i))); + } + } cute::gemm(mma_qk, tSrQ, tSrK, tSrS); } @@ -440,6 +455,13 @@ struct FMHAFwdMainloop< for (int VV = 0; VV < VTiles; VV++) { copy(copy_v, tVgV_cache(_, _, _, VV), tVrV); reorder(tVrV, tArV); + if constexpr (Fp8KV) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < tArV.size(); ++i) { + tArV(i) = static_cast( + params.scale_v * static_cast(tArV(i))); + } + } cute::gemm(mma_pv, tArP, tArV, tArA(_, _, _, VV)); } diff --git a/csrc/xpu/attn/xe_2/fmha_utils.hpp b/csrc/xpu/attn/xe_2/fmha_utils.hpp index 3b2f6d8ff..4cb09487d 100644 --- a/csrc/xpu/attn/xe_2/fmha_utils.hpp +++ b/csrc/xpu/attn/xe_2/fmha_utils.hpp @@ -8,22 +8,43 @@ #define HEAD_SIZE_LIMIT_3 192 #define HEAD_SIZE_LIMIT_4 256 -enum class CutlassType { - half, - bfloat16, +enum class CutlassDType { half, bfloat16, float8_e4m3, float8_e5m2 }; + +// Struct to carry separate Q and K dtypes without breaking existing API +struct CutlassQKType { + CutlassDType q_type; + CutlassDType k_type; + + // Convenience: construct with identical types + explicit CutlassQKType(CutlassDType t) : q_type(t), k_type(t) {} + CutlassQKType(CutlassDType q_t, CutlassDType k_t) + : q_type(q_t), k_type(k_t) {} }; -inline CutlassType aten_to_Cutlass_dtype(const at::Tensor& input) { - CutlassType cuType; - if (input.scalar_type() == torch::kHalf) { - cuType = CutlassType::half; - } else if (input.scalar_type() == torch::kBFloat16) { - cuType = CutlassType::bfloat16; - } else { - TORCH_INTERNAL_ASSERT( - false, "Current cutlass kernel only support half/bf16 data type."); +inline CutlassDType aten_to_dtype(const at::ScalarType st) { + if (st == torch::kHalf) { + return CutlassDType::half; + } else if (st == torch::kBFloat16) { + return CutlassDType::bfloat16; + } else if (st == torch::kFloat8_e4m3fn) { + return CutlassDType::float8_e4m3; + } else if (st == torch::kFloat8_e5m2) { + return CutlassDType::float8_e5m2; } - return cuType; + TORCH_INTERNAL_ASSERT( + false, + "Unsupported dtype: only half/bfloat16/float8_e4m3/float8_e5m2 supported " + "for Q/K."); +} + +inline CutlassDType aten_to_dtype(const at::Tensor& t) { + return aten_to_dtype(t.scalar_type()); +} + +// Helper to build Q/K dtype pair from tensors +inline CutlassQKType +aten_to_Cutlass_qk_dtype(const at::Tensor& q, const at::Tensor& k) { + return CutlassQKType(aten_to_dtype(q), aten_to_dtype(k)); } using namespace cute; diff --git a/csrc/xpu/attn/xe_2/fmha_xe2.cpp b/csrc/xpu/attn/xe_2/fmha_xe2.cpp index c1c951c35..eebaee004 100644 --- a/csrc/xpu/attn/xe_2/fmha_xe2.cpp +++ b/csrc/xpu/attn/xe_2/fmha_xe2.cpp @@ -15,6 +15,8 @@ void cutlass_chunk_prefill_xe2( const at::Tensor& cu_seqlens_k, int max_seqlen_q, int max_seqlen_k, + float k_scale, + float v_scale, double sm_scale, std::optional& sm_sink_, int window_size_left, @@ -35,6 +37,8 @@ void cutlass_chunk_prefill_xe2( cu_seqlens_k, max_seqlen_q, max_seqlen_k, + k_scale, + v_scale, sm_scale, sm_sink_, window_size_left, @@ -57,6 +61,8 @@ void cutlass_chunk_prefill_impl( const at::Tensor& cu_seqlens_k, int max_seqlen_q, int max_seqlen_k, + float k_scale, + float v_scale, double sm_scale, std::optional& sm_sink_, int window_size_left, @@ -118,6 +124,8 @@ void cutlass_chunk_prefill_impl( max_seqlen_k, total_seqlen_q, total_seqlen_k, + k_scale, + v_scale, static_cast(sm_scale), is_sink ? sm_sink_.value().data_ptr() : nullptr, batch_size, @@ -134,7 +142,7 @@ void cutlass_chunk_prefill_impl( is_local, is_sink}; - CutlassType cuType = aten_to_Cutlass_dtype(query); + CutlassQKType cuQKType = aten_to_Cutlass_qk_dtype(query, key_cache); static constexpr int max_head_size = 256; TORCH_CHECK( @@ -144,19 +152,19 @@ void cutlass_chunk_prefill_impl( if (args.head_size <= HEAD_SIZE_LIMIT_0) { policy_dispatch_func( - queue, cuType, args, is_paged, is_causal, is_local, is_sink); + queue, cuQKType, args, is_paged, is_causal, is_local, is_sink); } else if (args.head_size <= HEAD_SIZE_LIMIT_1) { policy_dispatch_func( - queue, cuType, args, is_paged, is_causal, is_local, is_sink); + queue, cuQKType, args, is_paged, is_causal, is_local, is_sink); } else if (args.head_size <= HEAD_SIZE_LIMIT_2) { policy_dispatch_func( - queue, cuType, args, is_paged, is_causal, is_local, is_sink); + queue, cuQKType, args, is_paged, is_causal, is_local, is_sink); } else if (args.head_size <= HEAD_SIZE_LIMIT_3) { policy_dispatch_func( - queue, cuType, args, is_paged, is_causal, is_local, is_sink); + queue, cuQKType, args, is_paged, is_causal, is_local, is_sink); } else if (args.head_size <= HEAD_SIZE_LIMIT_4) { policy_dispatch_func( - queue, cuType, args, is_paged, is_causal, is_local, is_sink); + queue, cuQKType, args, is_paged, is_causal, is_local, is_sink); } else { TORCH_CHECK(false, "Unsupported head size for fmha"); } diff --git a/csrc/xpu/attn/xe_2/fmha_xe2.h b/csrc/xpu/attn/xe_2/fmha_xe2.h index 099ecc17e..ed2a42e1d 100644 --- a/csrc/xpu/attn/xe_2/fmha_xe2.h +++ b/csrc/xpu/attn/xe_2/fmha_xe2.h @@ -11,6 +11,8 @@ void cutlass_chunk_prefill_xe2( const at::Tensor& cu_seqlens_k, int max_seqlen_q, int max_seqlen_k, + float k_scale, + float v_scale, double sm_scale, std::optional& sm_sink_, int window_size_left, diff --git a/csrc/xpu/attn/xe_2/paged_decode.hpp b/csrc/xpu/attn/xe_2/paged_decode.hpp index 6a6698d64..349d81622 100644 --- a/csrc/xpu/attn/xe_2/paged_decode.hpp +++ b/csrc/xpu/attn/xe_2/paged_decode.hpp @@ -419,9 +419,9 @@ struct PagedDecodeConfig { // Template function for explicit instantiation template void decode_policy_dispatch_impl( - sycl::queue& queue, CutlassType cuType, const paged_decode_args_t& args) { + sycl::queue& queue, CutlassDType cuType, const paged_decode_args_t& args) { const int PipelineStages = 1; - if (cuType == CutlassType::half) { + if (cuType == CutlassDType::half) { return PagedDecodeConfig< typename decode_policy::ShapeQK, typename decode_policy::ShapePV, diff --git a/csrc/xpu/attn/xe_2/paged_decode_extern.hpp b/csrc/xpu/attn/xe_2/paged_decode_extern.hpp index 5cf330011..22448b1a4 100644 --- a/csrc/xpu/attn/xe_2/paged_decode_extern.hpp +++ b/csrc/xpu/attn/xe_2/paged_decode_extern.hpp @@ -32,7 +32,7 @@ extern template void \ decode_policy_dispatch_impl( \ sycl::queue & queue, \ - CutlassType cuType, \ + CutlassDType cuType, \ const paged_decode_args_t& args); // Generate all 8 bool combinations for a given policy using nested macros diff --git a/csrc/xpu/attn/xe_2/paged_decode_kernel_template.cpp.in b/csrc/xpu/attn/xe_2/paged_decode_kernel_template.cpp.in index e5b209621..feb82f968 100644 --- a/csrc/xpu/attn/xe_2/paged_decode_kernel_template.cpp.in +++ b/csrc/xpu/attn/xe_2/paged_decode_kernel_template.cpp.in @@ -29,7 +29,7 @@ using namespace cute; static_cast(IMPL_KISLOCAL), \ static_cast(IMPL_KISSINK)>( \ sycl::queue & queue, \ - CutlassType cuType, \ + CutlassDType cuType, \ const paged_decode_args_t& args); INSTANTIATE_KERNEL() diff --git a/csrc/xpu/attn/xe_2/paged_decode_utils.hpp b/csrc/xpu/attn/xe_2/paged_decode_utils.hpp index 98ebd0289..5758097a8 100644 --- a/csrc/xpu/attn/xe_2/paged_decode_utils.hpp +++ b/csrc/xpu/attn/xe_2/paged_decode_utils.hpp @@ -5,14 +5,14 @@ using namespace cute; // Runtime dispatcher helper template void decode_policy_dispatch_func( - sycl::queue& queue, CutlassType cuType, const paged_decode_args_t& args) { + sycl::queue& queue, CutlassDType cuType, const paged_decode_args_t& args) { decode_policy_dispatch_impl(queue, cuType, args); } template void decode_policy_dispatch_func( sycl::queue& queue, - CutlassType cuType, + CutlassDType cuType, const paged_decode_args_t& args, bool b, Ts... ts) { @@ -29,7 +29,7 @@ template inline void dispatch_by_head_size( const int head_case, sycl::queue& queue, - CutlassType cuType, + CutlassDType cuType, const paged_decode_args_t& args) { switch (head_case) { case 0: diff --git a/csrc/xpu/attn/xe_2/paged_decode_xe2.cpp b/csrc/xpu/attn/xe_2/paged_decode_xe2.cpp index bf4b5bb7f..4ba39a477 100644 --- a/csrc/xpu/attn/xe_2/paged_decode_xe2.cpp +++ b/csrc/xpu/attn/xe_2/paged_decode_xe2.cpp @@ -150,7 +150,7 @@ void cutlass_paged_decode_impl( is_sink, num_kv_splits}; - CutlassType cuType = aten_to_Cutlass_dtype(query); + CutlassDType cuType = aten_to_dtype(query); static constexpr int max_head_size = 256; TORCH_CHECK( diff --git a/tests/flash_attn/test_flash_attn_varlen_func.py b/tests/flash_attn/test_flash_attn_varlen_func.py index 3ccf4d066..031229a2a 100644 --- a/tests/flash_attn/test_flash_attn_varlen_func.py +++ b/tests/flash_attn/test_flash_attn_varlen_func.py @@ -22,6 +22,7 @@ SINK = [False, True] CASUAL = [False, True] PAGED = [False, True] +FP8KV = [torch.float8_e5m2, torch.float8_e4m3fn, None] def ref_paged_attn(query: torch.Tensor, @@ -36,7 +37,8 @@ def ref_paged_attn(query: torch.Tensor, soft_cap: Optional[float] = None, is_paged: Optional[bool] = True, casual: Optional[bool] = False, - sink: Optional[torch.Tensor] = None) -> torch.Tensor: + sink: Optional[torch.Tensor] = None, + is_fp8kv: bool = False) -> torch.Tensor: num_seqs = len(query_lens) block_tables = block_tables.cpu().numpy() if is_paged: @@ -44,6 +46,10 @@ def ref_paged_attn(query: torch.Tensor, else: _, num_kv_heads, head_size = key_cache.shape + if is_fp8kv: + key_cache = key_cache.to(query.dtype) + value_cache = value_cache.to(query.dtype) + outputs: list[torch.Tensor] = [] start_idx = 0 start_idx_kv = 0 @@ -131,6 +137,7 @@ def ref_paged_attn(query: torch.Tensor, @pytest.mark.parametrize("is_sink", SINK) @pytest.mark.parametrize("is_casual", CASUAL) @pytest.mark.parametrize("is_paged", PAGED) +@pytest.mark.parametrize("fp8_dtype", FP8KV) @torch.inference_mode() def test_varlen_with_paged_kv( seq_lens: list[tuple[int, int]], @@ -146,6 +153,7 @@ def test_varlen_with_paged_kv( is_sink: bool, is_casual: bool, is_paged: bool, + fp8_dtype: Optional[torch.dtype], ) -> None: torch.set_default_device("xpu") torch.xpu.set_device("xpu:0") @@ -222,6 +230,11 @@ def test_varlen_with_paged_kv( q_descale = torch.ones(scale_shape, dtype=torch.float32) #noqa: F841 k_descale = torch.ones(scale_shape, dtype=torch.float32) #noqa: F841 v_descale = torch.ones(scale_shape, dtype=torch.float32) #noqa: F841 + is_fp8kv = False + if fp8_dtype is not None: + is_fp8kv = True + maybe_quantized_key_cache = key_cache.to(fp8_dtype) + maybe_quantized_value_cache = value_cache.to(fp8_dtype) if is_paged: output = flash_attn_varlen_func(maybe_quantized_query, @@ -251,8 +264,8 @@ def test_varlen_with_paged_kv( s_aux=sink) ref_output = ref_paged_attn(query=query, - key_cache=key_cache, - value_cache=value_cache, + key_cache=maybe_quantized_key_cache, + value_cache=maybe_quantized_value_cache, query_lens=query_lens, kv_lens=kv_lens, block_tables=block_tables, @@ -261,14 +274,18 @@ def test_varlen_with_paged_kv( is_paged=is_paged, sink=sink, window_size_left=window_size[0], - window_size_right=window_size[1]) + window_size_right=window_size[1], + is_fp8kv=is_fp8kv) atol, rtol = 1e-2, 1e-2 if q_dtype is not None: atol, rtol = 1.5e-1, 1.5e-1 if window_size[0] != -1 or window_size[1] != -1: atol, rtol = 1.5e-2, 1.5e-2 + if fp8_dtype is not None: + atol, rtol = 1.5e-2, 1.5e-2 torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \ f"{torch.max(torch.abs(output - ref_output))}" + torch.xpu.empty_cache() @pytest.mark.parametrize("seq_lens", @@ -386,3 +403,4 @@ def test_decode_with_paged_kv( atol, rtol = 1.5e-1, 1.5e-1 torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \ f"{torch.max(torch.abs(output - ref_output))}" + torch.xpu.empty_cache() diff --git a/vllm_xpu_kernels/flash_attn_interface.py b/vllm_xpu_kernels/flash_attn_interface.py index 7732d7c02..14c1c0477 100644 --- a/vllm_xpu_kernels/flash_attn_interface.py +++ b/vllm_xpu_kernels/flash_attn_interface.py @@ -63,6 +63,10 @@ def flash_attn_varlen_func( if softmax_scale is None: softmax_scale = q.shape[-1]**(-0.5) + if k_descale is None: + k_descale = 1.0 + if v_descale is None: + v_descale = 1.0 # custom op does not support non-tuple input real_window_size: tuple[int, int] if window_size is None: @@ -82,6 +86,16 @@ def flash_attn_varlen_func( "k_descale, v_descale") if num_splits > 1: raise NotImplementedError("FA2 does not support num_splits > 1") + if q_descale is not None: + raise NotImplementedError("FA2 does not support q_descale") + if scheduler_metadata is not None: + raise NotImplementedError( + "FA2 does not support scheduler_metadata") + if (k_descale is not None + and v_descale is None) or (k_descale is None + and v_descale is not None): + raise NotImplementedError( + "FA2 only supports both KV cache descaled") out, softmax_lse = torch.ops._vllm_fa2_C.varlen_fwd( q, k, @@ -98,6 +112,8 @@ def flash_attn_varlen_func( max_seqlen_q, max_seqlen_k, dropout_p, + k_descale, + v_descale, softmax_scale, s_aux, False,