diff --git a/csrc/flash_attn_ck/mha_bwd.cpp b/csrc/flash_attn_ck/mha_bwd.cpp index 083494f5b0c..c4aa1a015a0 100644 --- a/csrc/flash_attn_ck/mha_bwd.cpp +++ b/csrc/flash_attn_ck/mha_bwd.cpp @@ -4,52 +4,37 @@ #include "flash_common.hpp" -#include "fmha_bwd.hpp" +#include "mha_bwd.h" #include "mask.hpp" -fmha_bwd_traits get_ck_fmha_bwd_traits(const mask_info &mask, - std::string dtype, - int head_size, - bool has_dropout, - bool enable_alibi, - bool deterministic) -{ - return fmha_bwd_traits{head_size, - head_size, - dtype, - false, // is_group_mode - mask.type, - enable_alibi ? bias_enum::alibi : bias_enum::no_bias, - false, // has_dbias - has_dropout, - false, // s_randval - deterministic}; -} - -fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask, - // sizes - const int b, - const int seqlen_q, - const int seqlen_k, - const int h, - const int h_k, - const int hdim, - // device pointers - const at::Tensor q, - const at::Tensor k, - const at::Tensor v, - std::optional &alibi_slopes_, - const at::Tensor out, - const at::Tensor softmax_lse, - const at::Tensor dout, - at::Tensor dq_acc, - at::Tensor d, - at::Tensor dq, - at::Tensor dk, - at::Tensor dv, - float softmax_scale, - float p_dropout, - std::pair drop_seed_offset) +aiter::mha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask, + std::string dtype, + bool enable_alibi, + bool has_dropout, + bool deterministic, + // sizes + const int b, + const int seqlen_q, + const int seqlen_k, + const int h, + const int h_k, + const int hdim, + // device pointers + const at::Tensor q, + const at::Tensor k, + const at::Tensor v, + std::optional &alibi_slopes_, + const at::Tensor out, + const at::Tensor softmax_lse, + const at::Tensor dout, + at::Tensor dq_acc, + at::Tensor d, + at::Tensor dq, + at::Tensor dk, + at::Tensor dv, + float softmax_scale, + float p_dropout, + std::pair drop_seed_offset) { // q: (batch_size, seqlen_q, nheads, hdim) ck_tile::index_t batch_stride_q = q.stride(0); @@ -80,9 +65,6 @@ fmha_bwd_args 
get_ck_fmha_bwd_args(const mask_info &mask, ck_tile::index_t stride_do = dout.stride(1); ck_tile::index_t nhead_stride_do = dout.stride(2); - // d: (batch_size, nheads, seqlen_q) - // CK assume d share the same stride with lse - // dq: (batch_size, seqlen_q, nheads, hdim) ck_tile::index_t batch_stride_dq = dq.stride(0); ck_tile::index_t stride_dq = dq.stride(1); @@ -115,85 +97,95 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask, TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension"); TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h})); alibi_slopes_ptr = alibi_slopes.data_ptr(); - // alibi_slopes:(batch_size, nheads) or (nhead) stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0; } - return fmha_bwd_args{q.data_ptr(), - k.data_ptr(), - v.data_ptr(), - alibi_slopes_ptr, // bias - out.data_ptr(), - softmax_lse.data_ptr(), - dout.data_ptr(), - d.data_ptr(), - nullptr, // rand_val - dq.data_ptr(), - dk.data_ptr(), - dv.data_ptr(), - nullptr, // dbias - dq_acc.data_ptr(), // dq_acc - nullptr, // seqstart_q_ptr - nullptr, // seqstart_k_ptr - nullptr, // seqlen_q_ptr - nullptr, // seqlen_k_ptr - nullptr, // cu_seqlen_q_ptr - nullptr, // cu_seqlen_k_ptr - seqlen_q, - seqlen_k, - b, - seqlen_q, // max_seqlen_q - seqlen_k, // max_seqlen_k - hdim, // hdim_q - hdim, // hdim_v - h, // nhead - h_k, // nhead_k - softmax_scale, - stride_q, - stride_k, - stride_v, - stride_alibi_slopes, - stride_o, - 0, // stride_randval - stride_do, - stride_dq_acc, - stride_dq, - stride_dk, - stride_dv, - 0, // stride_dbias, FA without bias - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - 0, // nhead_stride_bias, FA without bias - nhead_stride_o, - 0, // nhead_stride_randval - nhead_stride_do, - nhead_stride_lse, - nhead_stride_dq_acc, - nhead_stride_dq, - nhead_stride_dk, - nhead_stride_dv, - 0, // nhead_stride_dbias, FA without dbias - 
batch_stride_q, - batch_stride_k, - batch_stride_v, - 0 , // batch_stride_bias, FA without bias - batch_stride_o, - 0, // batch_stride_randval - batch_stride_do, - batch_stride_lse, - batch_stride_dq_acc, - batch_stride_dq, - batch_stride_dk, - batch_stride_dv, - 0 , // batch_stride_dbias, FA without dbias - split_stride_dq_acc, - mask.left, - mask.right, - static_cast(mask.type), - p_dropout, - p_undrop, - drop_seed_offset}; + return aiter::mha_bwd_args{false, // use_asm_v3 + false, // v3_atomic_fp32 + 1, // v3_bf16_cvt + false, // v3_api_check + hdim, // hdim_q + hdim, // hdim_v + dtype, + false, // is_group_mode + static_cast(mask.type), + static_cast(enable_alibi ? bias_enum::alibi : bias_enum::no_bias), + false, // has_dbias + has_dropout, + false, // is_store_randval + deterministic, + q.data_ptr(), + k.data_ptr(), + v.data_ptr(), + alibi_slopes_ptr, // bias + out.data_ptr(), + softmax_lse.data_ptr(), + dout.data_ptr(), + d.data_ptr(), + nullptr, // rand_val + dq.data_ptr(), + dk.data_ptr(), + dv.data_ptr(), + nullptr, // dbias + dq_acc.data_ptr(), + nullptr, // seqstart_q_ptr + nullptr, // seqstart_k_ptr + nullptr, // seqlen_q_ptr + nullptr, // seqlen_k_ptr + nullptr, // cu_seqlen_q_ptr + nullptr, // cu_seqlen_k_ptr + seqlen_q, + seqlen_k, + b, + seqlen_q, // max_seqlen_q + seqlen_k, // max_seqlen_k + h, // nhead_q + h_k, // nhead_k + softmax_scale, + stride_q, + stride_k, + stride_v, + stride_alibi_slopes, + stride_o, + 0, // stride_randval + stride_do, + stride_dq_acc, + stride_dq, + stride_dk, + stride_dv, + 0, // stride_dbias + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + 0, // nhead_stride_bias + nhead_stride_o, + 0, // nhead_stride_randval + nhead_stride_do, + nhead_stride_lse, + static_cast(nhead_stride_dq_acc), + nhead_stride_dq, + nhead_stride_dk, + nhead_stride_dv, + 0, // nhead_stride_dbias + batch_stride_q, + batch_stride_k, + batch_stride_v, + 0, // batch_stride_bias + batch_stride_o, + 0, // batch_stride_randval + batch_stride_do, + 
batch_stride_lse, + static_cast(batch_stride_dq_acc), + batch_stride_dq, + batch_stride_dk, + batch_stride_dv, + 0, // batch_stride_dbias + split_stride_dq_acc, + mask.left, + mask.right, + p_dropout, + p_undrop, + drop_seed_offset}; } std::vector @@ -362,12 +354,13 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1); ck_tile::stream_config stream_config{stream}; - auto traits = - get_ck_fmha_bwd_traits(mask, q_dtype_str, head_size, is_dropout, alibi_slopes_.has_value(), deterministic); - auto args = get_ck_fmha_bwd_args( mask, + q_dtype_str, + alibi_slopes_.has_value(), + is_dropout, + deterministic, batch_size, seqlen_q, seqlen_k, @@ -390,7 +383,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num p_dropout, drop_seed_offset); - float t = fmha_bwd(traits, args, stream_config); + float t = aiter::mha_bwd(args, stream_config); TORCH_CHECK(t >= 0, "invalid argument for fmha_bwd"); } else { // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. diff --git a/csrc/flash_attn_ck/mha_fwd.cpp b/csrc/flash_attn_ck/mha_fwd.cpp index 0229e777cd5..51ce81deb3d 100644 --- a/csrc/flash_attn_ck/mha_fwd.cpp +++ b/csrc/flash_attn_ck/mha_fwd.cpp @@ -4,50 +4,32 @@ #include "flash_common.hpp" -#include "fmha_fwd.hpp" +#include "mha_fwd.h" #include "mask.hpp" -fmha_fwd_traits get_ck_fmha_fwd_traits(const mask_info &mask, - std::string dtype, - int head_size, - bool has_dropout, - bool has_lse, - bool enable_alibi) -{ - return fmha_fwd_traits{head_size, - head_size, - dtype, - false, // is_group_mode - true, // is_v_rowmajor - false, // has_logits_soft_cap - mask.type, - enable_alibi ? 
bias_enum::alibi : bias_enum::no_bias, - has_lse, - has_dropout, - quant_scale_enum::no_scale}; // qscale_type -} - -fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse, - bool has_dropout_randval, - const mask_info &mask, - // sizes - const int b, - const int seqlen_q, - const int seqlen_k, - const int h, - const int h_k, - const int d, - // device pointers - const at::Tensor q, - const at::Tensor k, - const at::Tensor v, - std::optional &alibi_slopes_, - at::Tensor out, - at::Tensor softmax_lse, - at::Tensor dropout_randval, - float softmax_scale, - float p_dropout, - std::pair drop_seed_offset) +aiter::mha_fwd_args get_ck_fmha_fwd_args(bool has_lse, + bool has_dropout_randval, + const mask_info &mask, + std::string dtype, + bool enable_alibi, + // sizes + const int b, + const int seqlen_q, + const int seqlen_k, + const int h, + const int h_k, + const int d, + // device pointers + const at::Tensor q, + const at::Tensor k, + const at::Tensor v, + std::optional &alibi_slopes_, + at::Tensor out, + at::Tensor softmax_lse, + at::Tensor dropout_randval, + float softmax_scale, + float p_dropout, + std::pair drop_seed_offset) { // q: (batch_size, seqlen_q, nheads, d) // k: (batch_size, seqlen_k, nheads_k, d) @@ -91,59 +73,80 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse, stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0; } - return fmha_fwd_args{q.data_ptr(), - k.data_ptr(), - v.data_ptr(), - alibi_slopes_ptr, // bias - nullptr, // q_descale_ptr - nullptr, // k_descale_ptr - nullptr, // v_descale_ptr - has_dropout_randval ? dropout_randval.data_ptr() : nullptr, - has_lse ? 
softmax_lse.data_ptr() : nullptr, - out.data_ptr(), - nullptr, // seqstart_q_ptr - nullptr, // seqstart_k_ptr - nullptr, // seqlen_q_ptr - nullptr, // seqlen_k_ptr - nullptr, // cu_seqlen_q_ptr - nullptr, // cu_seqlen_k_ptr - seqlen_q, - seqlen_k, - b, - seqlen_q, // max_seqlen_q - d, // hdim_q - d, // hdim_v - h, // nhead - h_k, // nhead_k - softmax_scale, // scale_s - 0.0f, // logits_soft_cap - stride_q, - stride_k, - stride_v, - stride_alibi_slopes, - stride_randval, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - 0, // nhead_stride_bias, FA without bias - nhead_stride_randval, - nhead_stride_lse, - nhead_stride_o, - batch_stride_q, - batch_stride_k, - batch_stride_v, - 0, // batch_stride_bias, FA without bias - batch_stride_randval, - batch_stride_lse, - batch_stride_o, - mask.left, - mask.right, - static_cast(mask.type), - 0, // min_seqlen_q - p_dropout, - has_dropout_randval, - drop_seed_offset}; + return aiter::mha_fwd_args{false, // use_asm_v3 + false, // v3_api_check + 1, // how_v3_bf16_cvt + dtype, + false, // is_group_mode + static_cast(enable_alibi ? bias_enum::alibi : bias_enum::no_bias), + has_lse, + static_cast(quant_scale_enum::no_scale), + false, // has_sink + q.data_ptr(), + k.data_ptr(), + v.data_ptr(), + alibi_slopes_ptr, // bias + nullptr, // q_descale_ptr + nullptr, // k_descale_ptr + nullptr, // v_descale_ptr + has_dropout_randval ? dropout_randval.data_ptr() : nullptr, + has_lse ? 
softmax_lse.data_ptr() : nullptr, + out.data_ptr(), + nullptr, // seqstart_q_ptr + nullptr, // seqstart_k_ptr + nullptr, // seqlen_q_ptr + nullptr, // seqlen_k_ptr + nullptr, // cu_seqlen_q_ptr + nullptr, // cu_seqlen_k_ptr + nullptr, // block_scale_seqstart_q_ptr + nullptr, // block_scale_seqstart_k_ptr + nullptr, // sink_ptr + seqlen_q, + seqlen_k, + b, + seqlen_q, // max_seqlen_q + d, // hdim_q + d, // hdim_v + h, // nhead_q + h_k, // nhead_k + softmax_scale, // scale_s + 0.0f, // logits_soft_cap + stride_q, + stride_k, + stride_v, + stride_alibi_slopes, + stride_randval, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + 0, // nhead_stride_bias + nhead_stride_randval, + nhead_stride_lse, + nhead_stride_o, + 0, // nhead_stride_q_descale + 0, // nhead_stride_k_descale + 0, // nhead_stride_v_descale + batch_stride_q, + batch_stride_k, + batch_stride_v, + 0, // batch_stride_bias + batch_stride_randval, + batch_stride_lse, + batch_stride_o, + 0, // batch_stride_q_descale + 0, // batch_stride_k_descale + 0, // batch_stride_v_descale + mask.left, + mask.right, + 0, // sink_size + static_cast(mask.type), + 0, // min_seqlen_q + p_dropout, + has_dropout_randval, + drop_seed_offset, + 128, // block_scale_size_q + 128}; // block_scale_size_kv } std::vector @@ -283,20 +286,13 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num #endif ck_tile::stream_config stream_config{stream}; - auto traits = - get_ck_fmha_fwd_traits( - mask, - q_dtype_str, - head_size, - has_dropout, - has_lse, - alibi_slopes_.has_value()); - auto args = get_ck_fmha_fwd_args( has_lse, return_dropout_randval, mask, + q_dtype_str, + alibi_slopes_.has_value(), batch_size, seqlen_q, seqlen_k, @@ -314,7 +310,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num p_dropout, drop_seed_offset); - float t = fmha_fwd(traits, args, stream_config); + float t = aiter::mha_fwd(args, stream_config); TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd"); } else { diff --git 
a/csrc/flash_attn_ck/mha_varlen_bwd.cpp b/csrc/flash_attn_ck/mha_varlen_bwd.cpp index 3cd01c32d48..93f98e2488e 100644 --- a/csrc/flash_attn_ck/mha_varlen_bwd.cpp +++ b/csrc/flash_attn_ck/mha_varlen_bwd.cpp @@ -4,109 +4,80 @@ #include "flash_common.hpp" -#include "fmha_bwd.hpp" +#include "mha_bwd.h" #include "mask.hpp" -fmha_bwd_traits get_ck_fmha_varlen_bwd_traits(const mask_info &mask, - std::string dtype, - int head_size, - bool has_dropout, - bool enable_alibi, - bool deterministic) -{ - return fmha_bwd_traits{head_size, - head_size, - dtype, - true, // is_group_mode - mask.type, - enable_alibi ? bias_enum::alibi : bias_enum::no_bias, - false, // has_dbias - has_dropout, - false, // s_randval - deterministic}; -} - -fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask, - // sizes - const int b, - const int max_seqlen_q, - const int max_seqlen_k, - const int h, - const int h_k, - const int hdim, - // device pointers - const at::Tensor q, - const at::Tensor k, - const at::Tensor v, - const at::Tensor seqlens_q, - const at::Tensor seqlens_k, - std::optional &alibi_slopes_, - const at::Tensor out, - const at::Tensor softmax_lse, - const at::Tensor dout, - at::Tensor dq_acc, - at::Tensor d, - at::Tensor dq, - at::Tensor dk, - at::Tensor dv, - float softmax_scale, - float p_dropout, - std::pair drop_seed_offset) +aiter::mha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask, + std::string dtype, + bool enable_alibi, + bool has_dropout, + bool deterministic, + // sizes + const int b, + const int max_seqlen_q, + const int max_seqlen_k, + const int h, + const int h_k, + const int hdim, + // device pointers + const at::Tensor q, + const at::Tensor k, + const at::Tensor v, + const at::Tensor seqlens_q, + const at::Tensor seqlens_k, + std::optional &alibi_slopes_, + const at::Tensor out, + const at::Tensor softmax_lse, + const at::Tensor dout, + at::Tensor dq_acc, + at::Tensor d, + at::Tensor dq, + at::Tensor dk, + at::Tensor dv, + float softmax_scale, + 
float p_dropout, + std::pair drop_seed_offset) { ck_tile::index_t total_q = q.size(0); ck_tile::index_t total_k = k.size(0); // q: (total_q, nheads, hdim) - ck_tile::index_t batch_stride_q = 0; ck_tile::index_t stride_q = q.stride(0); ck_tile::index_t nhead_stride_q = q.stride(1); // k: (total_k, nheads_k, hdim) - ck_tile::index_t batch_stride_k = 0; ck_tile::index_t stride_k = k.stride(0); ck_tile::index_t nhead_stride_k = k.stride(1); // v: (total_k, nheads_k, hdim) - ck_tile::index_t batch_stride_v = 0; ck_tile::index_t stride_v = v.stride(0); ck_tile::index_t nhead_stride_v = v.stride(1); // o: (total_q, nheads, hdim) - ck_tile::index_t batch_stride_o = 0; ck_tile::index_t stride_o = out.stride(0); ck_tile::index_t nhead_stride_o = out.stride(1); // lse: (nheads, total_q) - ck_tile::index_t batch_stride_lse = 0; ck_tile::index_t nhead_stride_lse = softmax_lse.stride(0); // do: (total_q, nheads, hdim) - ck_tile::index_t batch_stride_do = 0; ck_tile::index_t stride_do = dout.stride(0); ck_tile::index_t nhead_stride_do = dout.stride(1); - // d: (batch_size, nheads, max_seqlen_q) - // CK assume d share the same stride with lse - // dq: (total_q, nheads, hdim) - ck_tile::index_t batch_stride_dq = 0; ck_tile::index_t stride_dq = dq.stride(0); ck_tile::index_t nhead_stride_dq = dq.stride(1); - // dk_expanded: (total_k, nheads, hdim) - ck_tile::index_t batch_stride_dk = 0; ck_tile::index_t stride_dk = dk.stride(0); ck_tile::index_t nhead_stride_dk = dk.stride(1); // dv_expanded: (total_k, nheads, hdim) - ck_tile::index_t batch_stride_dv = 0; ck_tile::index_t stride_dv = dv.stride(0); ck_tile::index_t nhead_stride_dv = dv.stride(1); // dq_acc: (split, total_q, nheads, hdim) ck_tile::index_t split_stride_dq_acc = dq_acc.stride(0); - ck_tile::index_t batch_stride_dq_acc = 0; ck_tile::index_t stride_dq_acc = dq_acc.stride(1); ck_tile::index_t nhead_stride_dq_acc = dq_acc.stride(2); @@ -121,85 +92,95 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask, 
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension"); TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h})); alibi_slopes_ptr = alibi_slopes.data_ptr(); - // alibi_slopes:(batch_size, nheads) or (nhead) stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0; } - return fmha_bwd_args{q.data_ptr(), - k.data_ptr(), - v.data_ptr(), - alibi_slopes_ptr, // bias - out.data_ptr(), - softmax_lse.data_ptr(), - dout.data_ptr(), - d.data_ptr(), - nullptr, // rand_val - dq.data_ptr(), - dk.data_ptr(), - dv.data_ptr(), - nullptr, // dbias - dq_acc.data_ptr(), // dq_acc - seqlens_q.data_ptr(), // seqstart_q_ptr - seqlens_k.data_ptr(), // seqstart_k_ptr - nullptr, // seqlen_q_ptr - nullptr, // seqlen_k_ptr - nullptr, // cu_seqlen_q_ptr - nullptr, // cu_seqlen_k_ptr - total_q, - total_k, - b, - max_seqlen_q, // max_seqlen_q - max_seqlen_k, // max_seqlen_k - hdim, // hdim_q - hdim, // hdim_v - h, // nhead - h_k, // nhead_k - softmax_scale, - stride_q, - stride_k, - stride_v, - stride_alibi_slopes, - stride_o, - 0, // stride_randval - stride_do, - stride_dq_acc, - stride_dq, - stride_dk, - stride_dv, - 0, // stride_dbias, FA without bias - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - 0, // nhead_stride_bias, FA without bias - nhead_stride_o, - 0, // nhead_stride_randval - nhead_stride_do, - nhead_stride_lse, - nhead_stride_dq_acc, - nhead_stride_dq, - nhead_stride_dk, - nhead_stride_dv, - 0, // nhead_stride_dbias, FA without dbias - batch_stride_q, - batch_stride_k, - batch_stride_v, - 0 , // batch_stride_bias, FA without bias - batch_stride_o, - 0, // batch_stride_randval - batch_stride_do, - batch_stride_lse, - batch_stride_dq_acc, - batch_stride_dq, - batch_stride_dk, - batch_stride_dv, - 0 , // batch_stride_dbias, FA without dbias - split_stride_dq_acc, - mask.left, - mask.right, - static_cast(mask.type), - p_dropout, - p_undrop, - 
drop_seed_offset}; + return aiter::mha_bwd_args{false, // use_asm_v3 + false, // v3_atomic_fp32 + 1, // v3_bf16_cvt + false, // v3_api_check + hdim, // hdim_q + hdim, // hdim_v + dtype, + true, // is_group_mode + static_cast(mask.type), + static_cast(enable_alibi ? bias_enum::alibi : bias_enum::no_bias), + false, // has_dbias + has_dropout, + false, // is_store_randval + deterministic, + q.data_ptr(), + k.data_ptr(), + v.data_ptr(), + alibi_slopes_ptr, // bias + out.data_ptr(), + softmax_lse.data_ptr(), + dout.data_ptr(), + d.data_ptr(), + nullptr, // rand_val + dq.data_ptr(), + dk.data_ptr(), + dv.data_ptr(), + nullptr, // dbias + dq_acc.data_ptr(), + seqlens_q.data_ptr(), // seqstart_q_ptr + seqlens_k.data_ptr(), // seqstart_k_ptr + nullptr, // seqlen_q_ptr + nullptr, // seqlen_k_ptr + nullptr, // cu_seqlen_q_ptr + nullptr, // cu_seqlen_k_ptr + total_q, + total_k, + b, + max_seqlen_q, + max_seqlen_k, + h, // nhead_q + h_k, // nhead_k + softmax_scale, + stride_q, + stride_k, + stride_v, + stride_alibi_slopes, + stride_o, + 0, // stride_randval + stride_do, + stride_dq_acc, + stride_dq, + stride_dk, + stride_dv, + 0, // stride_dbias + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + 0, // nhead_stride_bias + nhead_stride_o, + 0, // nhead_stride_randval + nhead_stride_do, + nhead_stride_lse, + static_cast(nhead_stride_dq_acc), + nhead_stride_dq, + nhead_stride_dk, + nhead_stride_dv, + 0, // nhead_stride_dbias + 0, // batch_stride_q + 0, // batch_stride_k + 0, // batch_stride_v + 0, // batch_stride_bias + 0, // batch_stride_o + 0, // batch_stride_randval + 0, // batch_stride_do + 0, // batch_stride_lse + static_cast(0), // batch_stride_dq_acc + 0, // batch_stride_dq + 0, // batch_stride_dk + 0, // batch_stride_dv + 0, // batch_stride_dbias + split_stride_dq_acc, + mask.left, + mask.right, + p_dropout, + p_undrop, + drop_seed_offset}; } std::vector @@ -385,12 +366,13 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads auto drop_seed_offset = 
std::make_pair(rng_state_ptr, rng_state_ptr + 1); ck_tile::stream_config stream_config{stream}; - auto traits = - get_ck_fmha_varlen_bwd_traits(mask, q_dtype_str, head_size, is_dropout, alibi_slopes_.has_value(), deterministic); - auto args = get_ck_fmha_varlen_bwd_args( mask, + q_dtype_str, + alibi_slopes_.has_value(), + is_dropout, + deterministic, batch_size, max_seqlen_q, max_seqlen_k, @@ -415,7 +397,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads p_dropout, drop_seed_offset); - float t = fmha_bwd(traits, args, stream_config); + float t = aiter::mha_bwd(args, stream_config); TORCH_CHECK(t >= 0, "invalid argument for fmha_bwd"); } else { // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. @@ -431,4 +413,4 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads } return { dq, dk, dv, softmax_d }; -} \ No newline at end of file +} diff --git a/csrc/flash_attn_ck/mha_varlen_fwd.cpp b/csrc/flash_attn_ck/mha_varlen_fwd.cpp index 00b0fcd5738..c1f6745af5b 100644 --- a/csrc/flash_attn_ck/mha_varlen_fwd.cpp +++ b/csrc/flash_attn_ck/mha_varlen_fwd.cpp @@ -4,69 +4,34 @@ #include "flash_common.hpp" +#include "mha_fwd.h" #include "fmha_fwd.hpp" #include "mask.hpp" -fmha_fwd_traits get_ck_fmha_varlen_fwd_traits(const mask_info &mask, - std::string dtype, - int head_size, - bool has_dropout, - bool has_lse, - bool enable_alibi) -{ - return fmha_fwd_traits{head_size, - head_size, - dtype, - true, // is_group_mode - true, // is_v_rowmajor - false, // has_logits_soft_cap - mask.type, - enable_alibi ? 
bias_enum::alibi : bias_enum::no_bias, - has_lse, - has_dropout, - quant_scale_enum::no_scale}; // qscale_type -} - -fmha_fwd_splitkv_traits get_ck_fmha_varlen_fwd_splitkv_traits(const mask_info &mask, - std::string dtype, - int head_size, - bool has_lse, - bool enable_alibi) -{ - return fmha_fwd_splitkv_traits{head_size, - head_size, - dtype, - true, // is_group_mode - true, // is_v_rowmajor - false, // has_logits_soft_cap - mask.type, - enable_alibi ? bias_enum::alibi : bias_enum::no_bias, - has_lse, - false}; // do_fp8_static_quant -} - -fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse, - bool has_dropout_randval, - const mask_info &mask, - // sizes - const int b, - const int max_seqlen_q, - const int h, - const int h_k, - const int d, - // device pointers - const at::Tensor q, - const at::Tensor k, - const at::Tensor v, - const at::Tensor seqlens_q, - const at::Tensor seqlens_k, - std::optional &alibi_slopes_, - at::Tensor out, - at::Tensor softmax_lse, - at::Tensor dropout_randval, - float softmax_scale, - float p_dropout, - std::pair drop_seed_offset) +aiter::mha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse, + bool has_dropout_randval, + const mask_info &mask, + std::string dtype, + bool enable_alibi, + // sizes + const int b, + const int max_seqlen_q, + const int h, + const int h_k, + const int d, + // device pointers + const at::Tensor q, + const at::Tensor k, + const at::Tensor v, + const at::Tensor seqlens_q, + const at::Tensor seqlens_k, + std::optional &alibi_slopes_, + at::Tensor out, + at::Tensor softmax_lse, + at::Tensor dropout_randval, + float softmax_scale, + float p_dropout, + std::pair drop_seed_offset) { // q: (total_q, nheads, d) // k: (total_k, nheads_k, d) @@ -93,13 +58,6 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse, ck_tile::index_t nhead_stride_lse = has_lse ? softmax_lse.stride(0) : 0; ck_tile::index_t nhead_stride_randval = has_dropout_randval ? 
dropout_randval.stride(0) : 0; - ck_tile::index_t batch_stride_q = 0; - ck_tile::index_t batch_stride_k = 0; - ck_tile::index_t batch_stride_v = 0; - ck_tile::index_t batch_stride_o = 0; - ck_tile::index_t batch_stride_lse = 0; - ck_tile::index_t batch_stride_randval = 0; - void *alibi_slopes_ptr = nullptr; ck_tile::index_t stride_alibi_slopes = 0; @@ -112,59 +70,80 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse, stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0; } - return fmha_fwd_args{q.data_ptr(), - k.data_ptr(), - v.data_ptr(), - alibi_slopes_ptr, // bias - nullptr, // q_descale_ptr - nullptr, // k_descale_ptr - nullptr, // v_descale_ptr - has_dropout_randval ? dropout_randval.data_ptr() : nullptr, - has_lse ? softmax_lse.data_ptr() : nullptr, - out.data_ptr(), - seqlens_q.data_ptr(), // seqstart_q_ptr - seqlens_k.data_ptr(), // seqstart_k_ptr - nullptr, // seqlen_q_ptr - nullptr, // seqlen_k_ptr - nullptr, // cu_seqlen_q_ptr - nullptr, // cu_seqlen_kv_ptr - total_q, - total_k, - b, - max_seqlen_q, - d, // hdim_q - d, // hdim_v - h, // nhead - h_k, // nhead_k - softmax_scale, // scale_s - 0.0f, // logits_soft_cap - stride_q, - stride_k, - stride_v, - stride_alibi_slopes, - stride_randval, - stride_o, - nhead_stride_q, - nhead_stride_k, - nhead_stride_v, - 0, // nhead_stride_bias, FA without bias - nhead_stride_randval, - nhead_stride_lse, - nhead_stride_o, - batch_stride_q, - batch_stride_k, - batch_stride_v, - 0, // batch_stride_bias, FA without bias - batch_stride_randval, - batch_stride_lse, - batch_stride_o, - mask.left, - mask.right, - static_cast(mask.type), - 0, // min_seqlen_q - p_dropout, - has_dropout_randval, - drop_seed_offset}; + return aiter::mha_fwd_args{false, // use_asm_v3 + false, // v3_api_check + 1, // how_v3_bf16_cvt + dtype, + true, // is_group_mode + static_cast(enable_alibi ? 
bias_enum::alibi : bias_enum::no_bias), + has_lse, + static_cast(quant_scale_enum::no_scale), + false, // has_sink + q.data_ptr(), + k.data_ptr(), + v.data_ptr(), + alibi_slopes_ptr, // bias + nullptr, // q_descale_ptr + nullptr, // k_descale_ptr + nullptr, // v_descale_ptr + has_dropout_randval ? dropout_randval.data_ptr() : nullptr, + has_lse ? softmax_lse.data_ptr() : nullptr, + out.data_ptr(), + seqlens_q.data_ptr(), // seqstart_q_ptr + seqlens_k.data_ptr(), // seqstart_k_ptr + nullptr, // seqlen_q_ptr + nullptr, // seqlen_k_ptr + nullptr, // cu_seqlen_q_ptr + nullptr, // cu_seqlen_k_ptr + nullptr, // block_scale_seqstart_q_ptr + nullptr, // block_scale_seqstart_k_ptr + nullptr, // sink_ptr + total_q, + total_k, + b, + max_seqlen_q, + d, // hdim_q + d, // hdim_v + h, // nhead_q + h_k, // nhead_k + softmax_scale, // scale_s + 0.0f, // logits_soft_cap + stride_q, + stride_k, + stride_v, + stride_alibi_slopes, + stride_randval, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + 0, // nhead_stride_bias + nhead_stride_randval, + nhead_stride_lse, + nhead_stride_o, + 0, // nhead_stride_q_descale + 0, // nhead_stride_k_descale + 0, // nhead_stride_v_descale + 0, // batch_stride_q + 0, // batch_stride_k + 0, // batch_stride_v + 0, // batch_stride_bias + 0, // batch_stride_randval + 0, // batch_stride_lse + 0, // batch_stride_o + 0, // batch_stride_q_descale + 0, // batch_stride_k_descale + 0, // batch_stride_v_descale + mask.left, + mask.right, + 0, // sink_size + static_cast(mask.type), + 0, // min_seqlen_q + p_dropout, + has_dropout_randval, + drop_seed_offset, + 128, // block_scale_size_q + 128}; // block_scale_size_kv } fmha_fwd_splitkv_args get_ck_fmha_varlen_fwd_splitkv_args(bool has_lse, @@ -482,13 +461,17 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si if (paged_KV) { - auto traits = - get_ck_fmha_varlen_fwd_splitkv_traits( - mask, - q_dtype_str, - head_size, - has_lse, - alibi_slopes_.has_value()); + auto splitkv_traits = + 
fmha_fwd_splitkv_traits{head_size, + head_size, + q_dtype_str, + true, // is_group_mode + true, // is_v_rowmajor + false, // has_logits_soft_cap + mask.type, + alibi_slopes_.has_value() ? bias_enum::alibi : bias_enum::no_bias, + has_lse, + false}; // do_fp8_static_quant auto args = get_ck_fmha_varlen_fwd_splitkv_args( @@ -514,27 +497,20 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si softmax_lse_accum, out_accum); - float t = fmha_fwd_splitkv(traits, args, stream_config); + float t = fmha_fwd_splitkv(splitkv_traits, args, stream_config); TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd_splitkv"); } else { auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1); - auto traits = - get_ck_fmha_varlen_fwd_traits( - mask, - q_dtype_str, - head_size, - has_dropout, - has_lse, - alibi_slopes_.has_value()); - auto args = get_ck_fmha_varlen_fwd_args( has_lse, return_dropout_randval, mask, + q_dtype_str, + alibi_slopes_.has_value(), batch_size, max_seqlen_q, num_heads, @@ -553,7 +529,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si p_dropout, drop_seed_offset); - float t = fmha_fwd(traits, args, stream_config); + float t = aiter::mha_fwd(args, stream_config); TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd"); } } diff --git a/setup.py b/setup.py index 3bf4a904b00..4ba725cb21f 100644 --- a/setup.py +++ b/setup.py @@ -213,24 +213,17 @@ def validate_and_update_archs(archs): # We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp # files included in the source distribution, in case the user compiles from source. 
if IS_ROCM: + if os.path.isdir(".git"): + subprocess.run(["git", "submodule", "update", "--init", "--recursive", "third_party/aiter"], check=True) + else: + assert os.path.isdir("third_party/aiter"), ( + "third_party/aiter is missing, please use source distribution or git clone" + ) if ROCM_BACKEND == "triton": - if os.path.isdir(".git"): - subprocess.run(["git", "submodule", "update", "--init", "third_party/aiter"], check=True) - else: - assert os.path.isdir("third_party/aiter"), ( - "third_party/aiter is missing, please use source distribution or git clone" - ) subprocess.run( [sys.executable, "-m", "pip", "install", "--no-build-isolation", "third_party/aiter"], check=True, ) - elif ROCM_BACKEND == "ck": - if os.path.isdir(".git"): - subprocess.run(["git", "submodule", "update", "--init", "csrc/composable_kernel"], check=True) - else: - assert os.path.exists("csrc/composable_kernel/example/ck_tile/01_fmha/generate.py"), ( - "csrc/composable_kernel is missing, please use source distribution or git clone" - ) else: # CUDA: cutlass submodule if os.path.isdir(".git"): @@ -390,7 +383,8 @@ def validate_and_update_archs(archs): # Skips CK C++ extension compilation if using Triton Backend if ROCM_BACKEND == "ck": - ck_dir = "csrc/composable_kernel" + aiter_dir = "third_party/aiter" + ck_dir = f"{aiter_dir}/3rdparty/composable_kernel" #use codegen get code dispatch if not os.path.exists("./build"): @@ -402,6 +396,38 @@ def validate_and_update_archs(archs): subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd_splitkv", "--output_dir", "build", "--receipt", "2", "--optdim", optdim], check=True) subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "bwd", "--output_dir", "build", "--receipt", "2", "--optdim", optdim], check=True) + # Generate stub header for ASM v3 bwd configs (v3 ASM path is disabled, but header is required) + with open("build/asm_fmha_v3_bwd_configs.hpp", "w") as f: + f.write("""// 
Auto-generated stub - ASM v3 path is disabled +#pragma once +#include <string> +#include <unordered_map> + +struct fmha_v3_bwdConfig { + std::string knl_name; + std::string co_name; + std::string arch; + std::string dtype; + int hdim_q; + int hdim_v; + int mask; + int atomic32; + int pssk; + int pddv; + int mode; + int bf16_cvt; + int ts_qo; + int ts; +}; + +using CFG = std::unordered_map<std::string, fmha_v3_bwdConfig>; + +static CFG cfg_fmha_bwd_odo = {}; +static CFG cfg_fmha_bwd_dqdkdv = {}; +static CFG cfg_fmha_bwd_dq_convert = {}; +static CFG cfg_fmha_bwd_dq_shuffle = {}; +""") + # Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h # See https://github.com/pytorch/pytorch/pull/70650 generator_flag = [] @@ -435,6 +461,12 @@ def validate_and_update_archs(archs): f"build/fmha_*wd*.cpp" ) + # Aiter C++ interface sources (already .cu, added to renamed_sources below) + aiter_sources = [ + f"{aiter_dir}/csrc/cpp_itfs/mha_fwd.cu", + f"{aiter_dir}/csrc/cpp_itfs/mha_bwd.cu", + ] + + # Check if torch is using hipify v2. Until CK is updated with HIPIFY_V2 macro, # we must replace the incorrect APIs. 
maybe_hipify_v2_flag = [] @@ -449,7 +481,7 @@ def validate_and_update_archs(archs): "csrc/flash_attn_ck/mha_fwd_kvcache.cu", "csrc/flash_attn_ck/mha_fwd.cu", "csrc/flash_attn_ck/mha_varlen_bwd.cu", - "csrc/flash_attn_ck/mha_varlen_fwd.cu"] + glob.glob(f"build/fmha_*wd*.cu") + "csrc/flash_attn_ck/mha_varlen_fwd.cu"] + glob.glob(f"build/fmha_*wd*.cu") + aiter_sources cc_flag += ["-O3","-std=c++20", "-DCK_TILE_FMHA_FWD_FAST_EXP2=1", @@ -464,7 +496,8 @@ def validate_and_update_archs(archs): "-DCK_USE_XDL", "-DUSE_PROF_API=1", # "-DFLASHATTENTION_DISABLE_BACKWARD", - "-D__HIP_PLATFORM_HCC__=1"] + "-D__HIP_PLATFORM_HCC__=1", + "-DFAV2_ON=1"] cc_flag += [f"-DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT={os.environ.get('CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT', 3)}"] @@ -488,9 +521,11 @@ def validate_and_update_archs(archs): } include_dirs = [ - Path(this_dir) / "csrc" / "composable_kernel" / "include", - Path(this_dir) / "csrc" / "composable_kernel" / "library" / "include", - Path(this_dir) / "csrc" / "composable_kernel" / "example" / "ck_tile" / "01_fmha", + Path(this_dir) / ck_dir / "include", + Path(this_dir) / ck_dir / "library" / "include", + Path(this_dir) / ck_dir / "example" / "ck_tile" / "01_fmha", + Path(this_dir) / aiter_dir / "csrc" / "include", + Path(this_dir) / "build", ] ext_modules.append(