26 changes: 20 additions & 6 deletions csrc/flash_attn/flash_api.cpp
@@ -22,6 +22,8 @@ std::vector<at::Tensor> mha_varlen_fwd(
int max_seqlen_q,
int max_seqlen_k,
float p_dropout,
float k_scale,
float v_scale,
float softmax_scale,
std::optional<const at::Tensor>& softmax_sink_,
const bool zero_tensors,
@@ -32,14 +34,23 @@ std::vector<at::Tensor> mha_varlen_fwd(
const bool return_softmax,
std::optional<at::Generator> gen_) {
auto q_type = q.scalar_type();
auto k_type = k.scalar_type();
TORCH_CHECK(
q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16,
"VLLM Kernel XPU only supports fp16 and bf16 type");

TORCH_CHECK(
k.scalar_type() == q_type, "query and key must have the same dtype");
TORCH_CHECK(
v.scalar_type() == q_type, "query and value must have the same dtype");
v.scalar_type() == k_type, "key and value must have the same dtype");
bool is_fp8kv = false;
if (k_type == at::ScalarType::Float8_e5m2 ||
k_type == at::ScalarType::Float8_e4m3fn) {
is_fp8kv = true;
} else {
TORCH_CHECK(
k.scalar_type() == q_type, "query and key must have the same dtype");
TORCH_CHECK(
v.scalar_type() == q_type, "query and value must have the same dtype");
}

CHECK_DEVICE(q);
CHECK_DEVICE(k);
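
For context, a hypothetical caller-side sketch (an assumption for illustration, not code from this PR) of how a KV cache ends up on the new `is_fp8kv` path: q stays fp16/bf16, while K/V are stored as `Float8_e4m3fn` (or `Float8_e5m2`) with per-tensor scales chosen so the original values are recovered as `k ≈ k_scale * k_fp8`.

```cpp
// Hypothetical sketch (not from this PR): per-tensor e4m3 quantization of a
// bf16 key cache. 448.0f is the largest finite value of float8_e4m3fn, so
// dividing by k_scale maps the tensor's amax onto the fp8 range, and the
// kernel later multiplies by k_scale to dequantize.
#include <ATen/ATen.h>
#include <utility>

std::pair<at::Tensor, float> quantize_kv_e4m3(const at::Tensor& k_bf16) {
  float amax = k_bf16.abs().max().item<float>();
  float k_scale = amax / 448.0f;
  at::Tensor k_fp8 = (k_bf16 / k_scale).to(at::ScalarType::Float8_e4m3fn);
  return {k_fp8, k_scale};  // pass k_scale through to mha_varlen_fwd
}
```
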
@@ -94,7 +105,7 @@ std::vector<at::Tensor> mha_varlen_fwd(
bool is_local = (window_size_left != -1) | (window_size_right != -1);
bool is_sink = softmax_sink_.has_value();

if (max_seqlen_q > 1 || is_local || !is_paged) {
if (max_seqlen_q > 1 || is_local || !is_paged || is_fp8kv) {
at::Tensor seqlens_k = is_paged ? *seqused_k : cu_seqlens_k;

cutlass_chunk_prefill_interface(
@@ -108,6 +119,8 @@ std::vector<at::Tensor> mha_varlen_fwd(
seqlens_k,
max_seqlen_q,
max_seqlen_k,
k_scale,
v_scale,
softmax_scale,
softmax_sink_,
window_size_left,
@@ -182,8 +195,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"cu_seqlens_q, "
"Tensor cu_seqlens_k, Tensor? seqused_k, Tensor? leftpad_k, Tensor? "
"block_table, Tensor? alibi_slopes, "
"int max_seqlen_q, int max_seqlen_k, float p_dropout, float "
"softmax_scale, Tensor? softmax_sink, bool zero_tensors, "
"int max_seqlen_q, int max_seqlen_k, float p_dropout, float k_scale, "
"float v_scale, "
Collaborator: can reformat here.
"float softmax_scale, Tensor? softmax_sink, bool zero_tensors, "
"bool is_causal, int window_size_left, int window_size_right, float "
"softcap, bool return_softmax, "
"Generator? gen) -> Tensor[]");
4 changes: 4 additions & 0 deletions csrc/xpu/attn/attn_interface.cpp
@@ -17,6 +17,8 @@ void cutlass_chunk_prefill_interface(
const at::Tensor& cu_seqlens_k,
int max_seqlen_q,
int max_seqlen_k,
float k_scale,
float v_scale,
double sm_scale,
std::optional<const at::Tensor>& sm_sink_,
int window_size_left,
@@ -40,6 +42,8 @@ void cutlass_chunk_prefill_interface(
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
k_scale,
v_scale,
sm_scale,
sm_sink_,
window_size_left,
2 changes: 2 additions & 0 deletions csrc/xpu/attn/attn_interface.h
@@ -10,6 +10,8 @@ void cutlass_chunk_prefill_interface(
const at::Tensor& cu_seqlens_k,
int max_seqlen_q,
int max_seqlen_k,
float k_scale,
float v_scale,
double sm_scale,
std::optional<const at::Tensor>& sm_sink_,
int window_size_left,
134 changes: 106 additions & 28 deletions csrc/xpu/attn/xe_2/chunk_prefill.hpp
@@ -36,6 +36,8 @@ struct chunk_prefill_args_t {
int max_keys;
int total_seqlen_q;
int total_seqlen_k;
float k_scale = 1.0;
float v_scale = 1.0;
float sm_scale;
void* sm_sink;
int batch_size;
@@ -145,6 +147,8 @@ struct KernelLauncher {
stride_O,
reinterpret_cast<ElementQ*>(args.sm_sink)},
{args.sm_scale,
args.k_scale,
args.v_scale,
static_cast<int*>(args.block_table),
args.block_size,
args.max_blocks_per_seq,
@@ -315,35 +319,109 @@ struct FMHAConfig {

template <typename chunk_policy, bool Paged, bool Causal, bool Local, bool Sink>
void policy_dispatch_impl(
sycl::queue& queue, CutlassType cuType, const chunk_prefill_args_t& args) {
sycl::queue& queue,
CutlassQKType& cuQKType,
const chunk_prefill_args_t& args) {
const int PipelineStages = 2;
if (cuType == CutlassType::half) {
return FMHAConfig<
typename chunk_policy::ShapeQK,
typename chunk_policy::ShapePV,
typename chunk_policy::ShapeOut,
typename chunk_policy::SubgroupLayoutQK,
void,
PipelineStages,
Paged,
Causal,
Local,
Sink,
half_t,
half_t,
half_t,
half_t>::kernel_dispatch(queue, args);
if (cuQKType.q_type == CutlassDType::half) {
Collaborator: I prefer to make Q/KV type also template here in the future. I feel the current change will increase compile time a little bit. (A rough sketch of this idea appears after this file's diff.)
if (cuQKType.k_type == CutlassDType::half) {
return FMHAConfig<
typename chunk_policy::ShapeQK,
typename chunk_policy::ShapePV,
typename chunk_policy::ShapeOut,
typename chunk_policy::SubgroupLayoutQK,
void,
PipelineStages,
Paged,
Causal,
Local,
Sink,
half_t,
half_t,
half_t,
half_t>::kernel_dispatch(queue, args);
} else if (cuQKType.k_type == CutlassDType::float8_e4m3) {
return FMHAConfig<
typename chunk_policy::ShapeQK,
typename chunk_policy::ShapePV,
typename chunk_policy::ShapeOut,
typename chunk_policy::SubgroupLayoutQK,
void,
PipelineStages,
Paged,
Causal,
Local,
Sink,
half_t,
float_e4m3_t,
float_e4m3_t,
half_t>::kernel_dispatch(queue, args);
} else if (cuQKType.k_type == CutlassDType::float8_e5m2) {
return FMHAConfig<
typename chunk_policy::ShapeQK,
typename chunk_policy::ShapePV,
typename chunk_policy::ShapeOut,
typename chunk_policy::SubgroupLayoutQK,
void,
PipelineStages,
Paged,
Causal,
Local,
Sink,
half_t,
float_e5m2_t,
float_e5m2_t,
half_t>::kernel_dispatch(queue, args);
}
} else {
return FMHAConfig<
typename chunk_policy::ShapeQK,
typename chunk_policy::ShapePV,
typename chunk_policy::ShapeOut,
typename chunk_policy::SubgroupLayoutQK,
void,
PipelineStages,
Paged,
Causal,
Local,
Sink>::kernel_dispatch(queue, args);
if (cuQKType.k_type == CutlassDType::bfloat16) {
return FMHAConfig<
typename chunk_policy::ShapeQK,
typename chunk_policy::ShapePV,
typename chunk_policy::ShapeOut,
typename chunk_policy::SubgroupLayoutQK,
void,
PipelineStages,
Paged,
Causal,
Local,
Sink,
bfloat16_t,
bfloat16_t,
bfloat16_t,
bfloat16_t>::kernel_dispatch(queue, args);
} else if (cuQKType.k_type == CutlassDType::float8_e4m3) {
return FMHAConfig<
typename chunk_policy::ShapeQK,
typename chunk_policy::ShapePV,
typename chunk_policy::ShapeOut,
typename chunk_policy::SubgroupLayoutQK,
void,
PipelineStages,
Paged,
Causal,
Local,
Sink,
bfloat16_t,
float_e4m3_t,
float_e4m3_t,
bfloat16_t>::kernel_dispatch(queue, args);
} else if (cuQKType.k_type == CutlassDType::float8_e5m2) {
return FMHAConfig<
typename chunk_policy::ShapeQK,
typename chunk_policy::ShapePV,
typename chunk_policy::ShapeOut,
typename chunk_policy::SubgroupLayoutQK,
void,
PipelineStages,
Paged,
Causal,
Local,
Sink,
bfloat16_t,
float_e5m2_t,
float_e5m2_t,
bfloat16_t>::kernel_dispatch(queue, args);
}
}
}
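
Following up the collaborator's suggestion above, a rough sketch (hypothetical, not part of this change) of lifting the Q/KV element types to template parameters, so each type combination becomes one instantiation selected at a single dispatch boundary instead of a runtime branch inside `policy_dispatch_impl`:

```cpp
// Hypothetical sketch of the suggested refactor. The type pattern mirrors the
// branches above: ElementK == ElementV, and the output element matches ElementQ.
template <
    typename chunk_policy,
    bool Paged, bool Causal, bool Local, bool Sink,
    typename ElementQ, typename ElementKV>
void policy_dispatch_typed(
    sycl::queue& queue, const chunk_prefill_args_t& args) {
  constexpr int PipelineStages = 2;
  FMHAConfig<
      typename chunk_policy::ShapeQK,
      typename chunk_policy::ShapePV,
      typename chunk_policy::ShapeOut,
      typename chunk_policy::SubgroupLayoutQK,
      void, PipelineStages, Paged, Causal, Local, Sink,
      ElementQ, ElementKV, ElementKV, ElementQ>::kernel_dispatch(queue, args);
}
// e.g. policy_dispatch_typed<policy, P, C, L, S, bfloat16_t, float_e4m3_t>(queue, args);
```
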
2 changes: 1 addition & 1 deletion csrc/xpu/attn/xe_2/chunk_prefill_extern.hpp
@@ -26,7 +26,7 @@
extern template void \
policy_dispatch_impl<POLICY, PAGED, CAUSAL, LOCAL, SINK>( \
sycl::queue & queue, \
CutlassType cuType, \
CutlassQKType & cuQKType, \
const chunk_prefill_args_t& args);

// Generate all 16 bool combinations for a given policy using nested macros
2 changes: 1 addition & 1 deletion csrc/xpu/attn/xe_2/chunk_prefill_kernel_template.cpp.in
@@ -21,7 +21,7 @@ using namespace cute;
static_cast<bool>(IMPL_KISLOCAL), \
static_cast<bool>(IMPL_KISSINK)>( \
sycl::queue & queue, \
CutlassType cuType, \
CutlassQKType& cuQKType, \
const chunk_prefill_args_t& args);

INSTANTIATE_KERNEL()
15 changes: 10 additions & 5 deletions csrc/xpu/attn/xe_2/chunk_prefill_utils.hpp
@@ -4,22 +4,25 @@ using namespace cute;

template <typename chunk_policy, bool... Bs>
void policy_dispatch_func(
sycl::queue& queue, CutlassType cuType, const chunk_prefill_args_t& args) {
policy_dispatch_impl<chunk_policy, Bs...>(queue, cuType, args);
sycl::queue& queue,
CutlassQKType& cuQKType,
const chunk_prefill_args_t& args) {
policy_dispatch_impl<chunk_policy, Bs...>(queue, cuQKType, args);
}

template <typename chunk_policy, bool... Bs, typename... Ts>
void policy_dispatch_func(
sycl::queue& queue,
CutlassType cuType,
CutlassQKType& cuQKType,
const chunk_prefill_args_t& args,
bool b,
Ts... ts) {
if (b) {
policy_dispatch_func<chunk_policy, Bs..., true>(queue, cuType, args, ts...);
policy_dispatch_func<chunk_policy, Bs..., true>(
queue, cuQKType, args, ts...);
} else {
policy_dispatch_func<chunk_policy, Bs..., false>(
queue, cuType, args, ts...);
queue, cuQKType, args, ts...);
}
}

@@ -34,6 +37,8 @@ void cutlass_chunk_prefill_impl(
const at::Tensor& cu_seqlens_k,
int max_seqlen_q,
int max_seqlen_k,
float k_scale,
float v_scale,
double sm_scale,
std::optional<const at::Tensor>& sm_sink_,
int window_size_left,
22 changes: 22 additions & 0 deletions csrc/xpu/attn/xe_2/collective/chunk_prefill_mainloop.hpp
@@ -131,6 +131,9 @@ struct FMHAFwdMainloop<
using TensorK = TensorK_;
using TensorV = TensorV_;

using ElementQ = typename TensorQ::engine_type::value_type;
using ElementK = typename TensorK::engine_type::value_type;

using TensorQ2D =
decltype(TensorQ_{}(append<rank_v<TensorQ_>>(make_coord(_, _), 0)));
using TensorK2D =
@@ -178,13 +181,17 @@
using FragARow = decltype(reduce<1>(FragA{}, sycl::plus<void>{}));
using ElementA = typename TiledMMAPV::ValTypeD;

static constexpr bool Fp8KV =
is_any_of_v<ElementK, float_e5m2_t, float_e4m3_t>;
static constexpr bool CausalMask = CausalMask_;
static constexpr bool LocalMask = LocalMask_;
static constexpr bool PagedKV = PagedKV_;

// User-facing arguments
struct Arguments {
ElementS const scale;
Collaborator: what's scale here for?

Collaborator (author): softmax scale
ElementS const scale_k;
ElementS const scale_v;

// Paged KV Cache
int* ptr_page_table;
@@ -215,6 +222,8 @@ struct FMHAFwdMainloop<
ElementS val = args.scale * static_cast<ElementS>(kLog2e);
return Params{
val,
args.scale_k,
args.scale_v,
args.ptr_page_table,
args.page_size,
args.max_pages_per_seq,
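
On the question above: `scale` is the softmax scale (typically 1/sqrt(head_dim)), and the conversion above folds log2(e) into it so the inner loop can use the cheaper base-2 exponential. In LaTeX, with `val = scale * kLog2e` as computed above:

```latex
\exp(s_{ij} \cdot \mathrm{scale})
  = 2^{\, s_{ij} \cdot \mathrm{scale} \cdot \log_2 e}
  = 2^{\, s_{ij} \cdot \mathrm{val}}
```
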
@@ -375,6 +384,12 @@

reorder(tQrQ, tSrQ);
reorder(tKrK, tSrK);
if constexpr (Fp8KV) {
for (int i = 0; i < tSrK.size(); ++i) {
tSrK(i) = static_cast<ElementQ>(
params.scale_k * static_cast<float>(tSrK(i)));
}
}
cute::gemm(mma_qk, tSrQ, tSrK, tSrS);
}
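
The `Fp8KV` constant gating this `if constexpr` block comes from the `is_any_of_v` trait added earlier in this file. Hedged, a minimal stand-in with the same semantics (the real helper lives in CUTLASS/cute):

```cpp
// Minimal stand-in for is_any_of_v (assumption: matches the CUTLASS/cute
// helper's semantics): true iff T is one of the listed types.
#include <type_traits>

template <typename T, typename... Us>
inline constexpr bool is_any_of_v = (std::is_same_v<T, Us> || ...);

static_assert(is_any_of_v<float, float, double>);
static_assert(!is_any_of_v<int, float, double>);
```
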

@@ -440,6 +455,13 @@
for (int VV = 0; VV < VTiles; VV++) {
copy(copy_v, tVgV_cache(_, _, _, VV), tVrV);
reorder(tVrV, tArV);
if constexpr (Fp8KV) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < tArV.size(); ++i) {
tArV(i) = static_cast<ElementQ>(
params.scale_v * static_cast<float>(tArV(i)));
}
}
cute::gemm(mma_pv, tArP, tArV, tArA(_, _, _, VV));
}
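
Net effect of the two `if constexpr (Fp8KV)` loops (this one for V, the earlier one for K): per-tensor dequantization happens element-wise in registers before each MMA, since the MMAs themselves run in the fp16/bf16 query type. Writing K-hat and V-hat for the fp8-stored values:

```latex
S = Q \, \bigl(k_{\mathrm{scale}} \, \hat{K}\bigr)^{\!\top},
\qquad
O = P \, \bigl(v_{\mathrm{scale}} \, \hat{V}\bigr)
```
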
