diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index 2c890d47a03..aa66087f6d4 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -533,6 +533,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s int window_size_right, const float softcap, const bool return_softmax, + int num_splits, std::optional<at::Generator> gen_) { // Otherwise the kernel will be launched from cuda:0 device @@ -706,7 +707,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(params, batch_size, num_heads, head_size, max_seqlen_k, max_seqlen_q, head_size_rounded, - p_dropout, /*num_splits*/ 0, get_num_sm(get_current_device()), opts); + p_dropout, num_splits, get_num_sm(get_current_device()), opts); } if (leftpad_k_.has_value()) { diff --git a/csrc/flash_attn/flash_api_torch_lib.cpp b/csrc/flash_attn/flash_api_torch_lib.cpp index d1299c54cd9..be2d23c0da7 100644 --- a/csrc/flash_attn/flash_api_torch_lib.cpp +++ b/csrc/flash_attn/flash_api_torch_lib.cpp @@ -35,6 +35,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s int window_size_right, const float softcap, const bool return_softmax, + int num_splits, std::optional<at::Generator> gen_); std::vector<at::Tensor> @@ -109,7 +110,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "Tensor cu_seqlens_k, Tensor? seqused_k, Tensor? leftpad_k, Tensor? block_table, Tensor? alibi_slopes, " "int max_seqlen_q, int max_seqlen_k, float p_dropout, float softmax_scale, bool zero_tensors, " "bool is_causal, int window_size_left, int window_size_right, float softcap, bool return_softmax, " - "Generator? gen) -> Tensor[]"); + "int num_splits, Generator? gen) -> Tensor[]"); ops.impl("varlen_fwd", torch::kCUDA, make_pytorch_shim(&mha_varlen_fwd)); ops.def("fwd_kvcache(Tensor! q, Tensor kcache, Tensor vcache, Tensor? k, Tensor? v, Tensor? 
seqlens_k, " diff --git a/vllm_flash_attn/flash_attn_interface.py b/vllm_flash_attn/flash_attn_interface.py index 96f9335841d..a61d0f9613a 100644 --- a/vllm_flash_attn/flash_attn_interface.py +++ b/vllm_flash_attn/flash_attn_interface.py @@ -271,6 +271,7 @@ def flash_attn_varlen_func( real_window_size[1], softcap, return_softmax_lse and dropout_p > 0, + num_splits, None, ) elif fa_version == 3: