diff --git a/csrc/composable_kernel b/csrc/composable_kernel
index 13f6d635653..574c1c121a0 160000
--- a/csrc/composable_kernel
+++ b/csrc/composable_kernel
@@ -1 +1 @@
-Subproject commit 13f6d635653bd5ffbfcac8577f1ef09590c23d78
+Subproject commit 574c1c121a0f3c0b44155b2b1987d89d16159b58
diff --git a/csrc/flash_attn_ck/mha_bwd.cpp b/csrc/flash_attn_ck/mha_bwd.cpp
index 083494f5b0c..19a269a0344 100644
--- a/csrc/flash_attn_ck/mha_bwd.cpp
+++ b/csrc/flash_attn_ck/mha_bwd.cpp
@@ -9,13 +9,25 @@
 
 fmha_bwd_traits get_ck_fmha_bwd_traits(const mask_info &mask,
                                        std::string dtype,
+                                       int seqlen_q,
+                                       int seqlen_k,
+                                       int batch,
                                        int head_size,
+                                       int nhead_q,
+                                       int nhead_k,
                                        bool has_dropout,
                                        bool enable_alibi,
                                        bool deterministic)
 {
-    return fmha_bwd_traits{head_size,
-                           head_size,
+    return fmha_bwd_traits{seqlen_q,
+                           seqlen_k,
+                           batch,
+                           seqlen_q,  // max_seqlen_q
+                           seqlen_k,  // max_seqlen_k
+                           head_size, // hdim_q
+                           head_size, // hdim_k
+                           nhead_q,
+                           nhead_k,
                            dtype,
                            false, // is_group_mode
                            mask.type,
@@ -98,11 +110,11 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
     ck_tile::index_t stride_dv = dv.stride(1);
     ck_tile::index_t nhead_stride_dv = dv.stride(2);
 
-    // dq_acc: (split, batch_size, seqlen_q, nheads, hdim)
-    ck_tile::index_t split_stride_dq_acc = dq_acc.stride(0);
-    ck_tile::index_t batch_stride_dq_acc = dq_acc.stride(1);
-    ck_tile::index_t stride_dq_acc = dq_acc.stride(2);
-    ck_tile::index_t nhead_stride_dq_acc = dq_acc.stride(3);
+    // dq_acc: (batch_size, nheads, split, seqlen_q, hdim)
+    ck_tile::long_index_t batch_stride_dq_acc = dq_acc.stride(0);
+    ck_tile::long_index_t nhead_stride_dq_acc = dq_acc.stride(1);
+    ck_tile::index_t split_stride_dq_acc = dq_acc.stride(2);
+    ck_tile::index_t stride_dq_acc = dq_acc.stride(3);
 
     float p_undrop = 1.0 - p_dropout;
 
@@ -222,7 +234,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
 #endif
     if (is_causal) { window_size_right = 0; }
 
-    bool is_dropout = p_dropout > 0.0;
+    const bool is_dropout = p_dropout > 0.0;
 #ifdef HIPIFY_V2
     auto stream = at::cuda::getCurrentCUDAStream().stream();
 #else
@@ -238,7 +250,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
     TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
     TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
 
-    std::string q_dtype_str = q_dtype == torch::kFloat16 ? "fp16" : "bf16";
+    const std::string q_dtype_str = q_dtype == torch::kFloat16 ? "fp16" : "bf16";
 
     CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
     CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
@@ -316,19 +328,26 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
         dv = torch::empty_like(v);
     }
 
+    const auto traits = get_ck_fmha_bwd_traits(
+        mask,
+        q_dtype_str,
+        seqlen_q,
+        seqlen_k,
+        batch_size,
+        head_size,
+        num_heads,
+        num_heads_k,
+        is_dropout,
+        alibi_slopes_.has_value(),
+        deterministic);
+    fmha_bwd_launcher launcher(traits);
+    const ck_tile::index_t nsplits = launcher.dq_acc_splits;
+
     at::cuda::CUDAGuard device_guard{q.device()};
     auto opts = q.options();
     auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
-    at::Tensor dq_accum;
-
-    if (!deterministic) {
-        dq_accum = torch::zeros({1, batch_size, seqlen_q, num_heads, head_size}, opts.dtype(at::kFloat));
-    } else {
-        const ck_tile::index_t kN0 = head_size <= 128 ? 128 : 64;
-        const ck_tile::index_t nsplits = ck_tile::integer_divide_ceil(seqlen_k, kN0);
-        dq_accum = torch::zeros({nsplits, batch_size, seqlen_q, num_heads, head_size}, opts.dtype(at::kFloat));
-    }
+    at::Tensor dq_accum = torch::zeros({batch_size, num_heads, nsplits, seqlen_q, head_size}, opts.dtype(at::kFloat));
 
     at::Tensor dk_expanded, dv_expanded;
 
     if (num_heads_k != num_heads) { // MQA / GQA
@@ -362,9 +381,6 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
     auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
 
     ck_tile::stream_config stream_config{stream};
-    auto traits =
-        get_ck_fmha_bwd_traits(mask, q_dtype_str, head_size, is_dropout, alibi_slopes_.has_value(), deterministic);
-
     auto args =
         get_ck_fmha_bwd_args(
             mask,
diff --git a/csrc/flash_attn_ck/mha_fwd.cpp b/csrc/flash_attn_ck/mha_fwd.cpp
index 0229e777cd5..44f7f4f0d93 100644
--- a/csrc/flash_attn_ck/mha_fwd.cpp
+++ b/csrc/flash_attn_ck/mha_fwd.cpp
@@ -107,6 +107,10 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
                         nullptr, // seqlen_k_ptr
                         nullptr, // cu_seqlen_q_ptr
                         nullptr, // cu_seqlen_k_ptr
+                        nullptr, // block_scale_seqstart_q_ptr
+                        nullptr, // block_scale_seqstart_k_ptr
+                        nullptr, // seqstart_v_scale_ptr
+                        nullptr, // sink_ptr
                         seqlen_q,
                         seqlen_k,
                         b,
@@ -123,6 +127,9 @@
                         stride_alibi_slopes,
                         stride_randval,
                         stride_o,
+                        0, // stride_q_descale
+                        0, // stride_k_descale
+                        0, // stride_v_descale
                         nhead_stride_q,
                         nhead_stride_k,
                         nhead_stride_v,
@@ -130,6 +137,9 @@
                         nhead_stride_randval,
                         nhead_stride_lse,
                         nhead_stride_o,
+                        0, // nhead_stride_q_descale
+                        0, // nhead_stride_k_descale
+                        0, // nhead_stride_v_descale
                         batch_stride_q,
                         batch_stride_k,
                         batch_stride_v,
@@ -137,13 +147,19 @@
                         batch_stride_randval,
                         batch_stride_lse,
                         batch_stride_o,
+                        0, // batch_stride_q_descale
+                        0, // batch_stride_k_descale
+                        0, // batch_stride_v_descale
                         mask.left,
                         mask.right,
+                        0, // sink_size
                         static_cast<ck_tile::index_t>(mask.type),
                         0, // min_seqlen_q
                         p_dropout,
                         has_dropout_randval,
-                        drop_seed_offset};
+                        drop_seed_offset,
+                        0,  // block_scale_size_q
+                        0}; // block_scale_size_kv
 }
 
 std::vector<at::Tensor>
diff --git a/csrc/flash_attn_ck/mha_fwd_kvcache.cpp b/csrc/flash_attn_ck/mha_fwd_kvcache.cpp
index 27866f1902e..2a478334df1 100644
--- a/csrc/flash_attn_ck/mha_fwd_kvcache.cpp
+++ b/csrc/flash_attn_ck/mha_fwd_kvcache.cpp
@@ -32,13 +32,14 @@ fmha_fwd_splitkv_traits get_ck_fmha_fwd_splitkv_traits(const mask_info &mask,
     return fmha_fwd_splitkv_traits{head_size,
                                    head_size,
                                    dtype,
-                                   false, // is_group_mode
-                                   true, // is_v_rowmajor
-                                   false, // has_logits_soft_cap
+                                   false,  // is_group_mode
+                                   true,   // is_v_rowmajor
+                                   false,  // has_logits_soft_cap
                                    mask.type,
                                    enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
                                    has_lse,
-                                   false}; // do_fp8_static_quant
+                                   false,  // do_fp8_static_quant
+                                   false}; // has_sink
 }
 
 fmha_fwd_appendkv_args get_ck_fmha_fwd_appendkv_args(const int b,
@@ -177,6 +178,7 @@
     args.o_acc_ptr = out_acc.data_ptr();
     args.lse_ptr = nullptr;
     args.o_ptr = out.data_ptr();
+    args.sink_ptr = nullptr;
 
     if (block_table_.has_value())
     {
@@ -261,6 +263,7 @@
 
     args.window_size_left = mask.left;
     args.window_size_right = mask.right;
+    args.sink_size = 0;
     args.mask_type = static_cast<ck_tile::index_t>(mask.type);
 
     return args;
diff --git a/csrc/flash_attn_ck/mha_varlen_bwd.cpp b/csrc/flash_attn_ck/mha_varlen_bwd.cpp
index 3cd01c32d48..68618f8cecc 100644
--- a/csrc/flash_attn_ck/mha_varlen_bwd.cpp
+++ b/csrc/flash_attn_ck/mha_varlen_bwd.cpp
@@ -9,13 +9,27 @@
 
 fmha_bwd_traits get_ck_fmha_varlen_bwd_traits(const mask_info &mask,
                                               std::string dtype,
+                                              int seqlen_q,
+                                              int seqlen_k,
+                                              int batch,
+                                              int max_seqlen_q,
+                                              int max_seqlen_k,
                                               int head_size,
+                                              int nhead_q,
+                                              int nhead_k,
                                               bool has_dropout,
                                               bool enable_alibi,
                                               bool deterministic)
 {
-    return fmha_bwd_traits{head_size,
-                           head_size,
+    return fmha_bwd_traits{seqlen_q,
+                           seqlen_k,
+                           batch,
+                           max_seqlen_q,
+                           max_seqlen_k,
+                           head_size, // hdim_q
+                           head_size, // hdim_k
+                           nhead_q,
+                           nhead_k,
                            dtype,
                            true, // is_group_mode
                            mask.type,
@@ -25,7 +39,6 @@
                            false, // s_randval
                            deterministic};
 }
-
 fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
                                           // sizes
                                           const int b,
@@ -104,11 +117,11 @@
     ck_tile::index_t stride_dv = dv.stride(0);
     ck_tile::index_t nhead_stride_dv = dv.stride(1);
 
-    // dq_acc: (split, total_q, nheads, hdim)
-    ck_tile::index_t split_stride_dq_acc = dq_acc.stride(0);
-    ck_tile::index_t batch_stride_dq_acc = 0;
-    ck_tile::index_t stride_dq_acc = dq_acc.stride(1);
-    ck_tile::index_t nhead_stride_dq_acc = dq_acc.stride(2);
+    // dq_acc: (nheads, split, total_q, hdim)
+    ck_tile::long_index_t batch_stride_dq_acc = 0;
+    ck_tile::long_index_t nhead_stride_dq_acc = dq_acc.stride(0);
+    ck_tile::index_t split_stride_dq_acc = dq_acc.stride(1);
+    ck_tile::index_t stride_dq_acc = dq_acc.stride(2);
 
     float p_undrop = 1.0 - p_dropout;
 
@@ -233,7 +246,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads
 #endif
     if (is_causal) { window_size_right = 0; }
 
-    bool is_dropout = p_dropout > 0.0;
+    const bool is_dropout = p_dropout > 0.0;
     auto stream = at::cuda::getCurrentCUDAStream().stream();
 
     auto q_dtype = q.dtype();
@@ -247,7 +260,7 @@
     TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
     TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
 
-    std::string q_dtype_str = q_dtype == torch::kFloat16 ? "fp16" : "bf16";
+    const std::string q_dtype_str = q_dtype == torch::kFloat16 ? "fp16" : "bf16";
 
     CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
     CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
@@ -330,19 +343,28 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads
         dv = torch::empty_like(v);
     }
 
+    const auto traits = get_ck_fmha_varlen_bwd_traits(
+        mask,
+        q_dtype_str,
+        total_q,
+        total_k,
+        batch_size,
+        max_seqlen_q,
+        max_seqlen_k,
+        head_size,
+        num_heads,
+        num_heads_k,
+        is_dropout,
+        alibi_slopes_.has_value(),
+        deterministic);
+    fmha_bwd_launcher launcher(traits);
+    const ck_tile::index_t nsplits = launcher.dq_acc_splits;
+
     at::cuda::CUDAGuard device_guard{q.device()};
     auto opts = q.options();
     auto softmax_d = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
-    at::Tensor dq_accum;
-
-    if (!deterministic) {
-        dq_accum = torch::zeros({1, total_q, num_heads, head_size}, opts.dtype(at::kFloat));
-    } else {
-        const ck_tile::index_t kN0 = head_size <= 128 ? 128 : 64;
-        const ck_tile::index_t nsplits = ck_tile::integer_divide_ceil(max_seqlen_k, kN0);
-        dq_accum = torch::zeros({nsplits, total_q, num_heads, head_size}, opts.dtype(at::kFloat));
-    }
+    at::Tensor dq_accum = torch::zeros({num_heads, nsplits, total_q, head_size}, opts.dtype(at::kFloat));
 
     at::Tensor dk_expanded, dv_expanded;
 
     if (num_heads_k != num_heads) { // MQA / GQA
@@ -385,9 +407,6 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads
     auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
 
     ck_tile::stream_config stream_config{stream};
-    auto traits =
-        get_ck_fmha_varlen_bwd_traits(mask, q_dtype_str, head_size, is_dropout, alibi_slopes_.has_value(), deterministic);
-
     auto args =
         get_ck_fmha_varlen_bwd_args(
             mask,
diff --git a/csrc/flash_attn_ck/mha_varlen_fwd.cpp b/csrc/flash_attn_ck/mha_varlen_fwd.cpp
index 00b0fcd5738..5bf60a82d7d 100644
--- a/csrc/flash_attn_ck/mha_varlen_fwd.cpp
+++ b/csrc/flash_attn_ck/mha_varlen_fwd.cpp
@@ -42,7 +42,8 @@ fmha_fwd_splitkv_traits get_ck_fmha_varlen_fwd_splitkv_traits(const mask_info &m
                                    mask.type,
                                    enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
                                    has_lse,
-                                   false}; // do_fp8_static_quant
+                                   false,  // do_fp8_static_quant
+                                   false}; // has_sink
 }
 
 fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
@@ -128,6 +129,10 @@
                         nullptr, // seqlen_k_ptr
                         nullptr, // cu_seqlen_q_ptr
                         nullptr, // cu_seqlen_kv_ptr
+                        nullptr, // block_scale_seqstart_q_ptr
+                        nullptr, // block_scale_seqstart_k_ptr
+                        nullptr, // seqstart_v_scale_ptr
+                        nullptr, // sink_ptr
                         total_q,
                         total_k,
                         b,
@@ -144,6 +149,9 @@
                         stride_alibi_slopes,
                         stride_randval,
                         stride_o,
+                        0, // stride_q_descale
+                        0, // stride_k_descale
+                        0, // stride_v_descale
                         nhead_stride_q,
                         nhead_stride_k,
                         nhead_stride_v,
@@ -151,6 +159,9 @@
                         nhead_stride_randval,
                         nhead_stride_lse,
                         nhead_stride_o,
+                        0, // nhead_stride_q_descale
+                        0, // nhead_stride_k_descale
+                        0, // nhead_stride_v_descale
                         batch_stride_q,
                         batch_stride_k,
                         batch_stride_v,
@@ -158,13 +169,19 @@
                         batch_stride_randval,
                         batch_stride_lse,
                         batch_stride_o,
+                        0, // batch_stride_q_descale
+                        0, // batch_stride_k_descale
+                        0, // batch_stride_v_descale
                         mask.left,
                         mask.right,
+                        0, // sink_size
                         static_cast<ck_tile::index_t>(mask.type),
                         0, // min_seqlen_q
                         p_dropout,
                         has_dropout_randval,
-                        drop_seed_offset};
+                        drop_seed_offset,
+                        0,  // block_scale_size_q
+                        0}; // block_scale_size_kv
 }
 
 fmha_fwd_splitkv_args get_ck_fmha_varlen_fwd_splitkv_args(bool has_lse,
@@ -210,6 +227,7 @@
     args.o_acc_ptr = out_acc.data_ptr();
     args.lse_ptr = nullptr;
     args.o_ptr = out.data_ptr();
+    args.sink_ptr = nullptr;
 
     if (block_table_.has_value())
     {
@@ -293,6 +311,7 @@
 
     args.window_size_left = mask.left;
     args.window_size_right = mask.right;
+    args.sink_size = 0;
     args.mask_type = static_cast<ck_tile::index_t>(mask.type);
 
     return args;
diff --git a/setup.py b/setup.py
index 730a190a876..cad936f822e 100644
--- a/setup.py
+++ b/setup.py
@@ -437,6 +437,8 @@ def validate_and_update_archs(archs):
                    "csrc/flash_attn_ck/mha_varlen_fwd.cu"] + glob.glob(f"build/fmha_*wd*.cu")
 
         cc_flag += ["-O3","-std=c++20",
+                    "-Wno-unknown-warning-option",
+                    "-fbracket-depth=1024",
                     "-DCK_TILE_FMHA_FWD_FAST_EXP2=1",
                     "-fgpu-flush-denormals-to-zero",
                     "-DCK_ENABLE_BF16",
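
A note on the backward-pass hunks (mha_bwd.cpp / mha_varlen_bwd.cpp): dq_accum changes layout from (split, batch_size, seqlen_q, nheads, hdim) to (batch_size, nheads, split, seqlen_q, hdim), the split count now comes from fmha_bwd_launcher::dq_acc_splits in the bumped composable_kernel submodule rather than from the local kN0 heuristic, and the batch/nhead strides widen from ck_tile::index_t to ck_tile::long_index_t. The sketch below is illustrative only, not part of the patch: it uses int64_t/int32_t as stand-ins for ck_tile::long_index_t/ck_tile::index_t, picks example extents, and reuses the old kN0-based formula purely to get a plausible nsplits (the patched code takes it from the launcher). It shows why the two outermost strides of the new layout need 64 bits: they now absorb the split factor.

    #include <cstdint>
    #include <cstdio>

    int main() {
        // Example extents (assumed, not taken from the patch).
        const int64_t batch = 4, nheads = 64, seqlen_q = 16384, seqlen_k = 16384, hdim = 128;
        const int64_t kN0 = hdim <= 128 ? 128 : 64;         // old in-tree heuristic
        const int64_t nsplits = (seqlen_k + kN0 - 1) / kN0; // integer_divide_ceil equivalent

        // Contiguous strides of (batch_size, nheads, split, seqlen_q, hdim),
        // matching the dq_acc.stride(0..3) reads in get_ck_fmha_bwd_args.
        const int64_t stride_dq_acc       = hdim;                          // dq_acc.stride(3)
        const int64_t split_stride_dq_acc = seqlen_q * hdim;               // dq_acc.stride(2)
        const int64_t nhead_stride_dq_acc = nsplits * split_stride_dq_acc; // dq_acc.stride(1)
        const int64_t batch_stride_dq_acc = nheads * nhead_stride_dq_acc;  // dq_acc.stride(0)

        // batch stride = 64 * 128 * 16384 * 128 = 17179869184 > INT32_MAX,
        // which is why the patch types the batch/nhead strides as ck_tile::long_index_t.
        std::printf("strides: batch=%lld nhead=%lld split=%lld row=%lld\n",
                    (long long)batch_stride_dq_acc, (long long)nhead_stride_dq_acc,
                    (long long)split_stride_dq_acc, (long long)stride_dq_acc);
        return 0;
    }

The varlen path follows the same arithmetic with total_q in place of batch_size x seqlen_q; there batch_stride_dq_acc stays pinned to 0 because all sequences share one flattened buffer.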