diff --git a/csrc/composable_kernel b/csrc/composable_kernel index e8709c24f40..46f1d740f03 160000 --- a/csrc/composable_kernel +++ b/csrc/composable_kernel @@ -1 +1 @@ -Subproject commit e8709c24f403173ad21a2da907d1347957e324fb +Subproject commit 46f1d740f03d11bc2a78fce60a95cd0933b9dd4d diff --git a/csrc/flash_attn_ck/mha_bwd.cpp b/csrc/flash_attn_ck/mha_bwd.cpp index bb879453680..6f85d37e5dc 100644 --- a/csrc/flash_attn_ck/mha_bwd.cpp +++ b/csrc/flash_attn_ck/mha_bwd.cpp @@ -135,7 +135,10 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask, dq_acc.data_ptr(), // dq_acc nullptr, // seqstart_q nullptr, // seqstart_k + nullptr, // seqlen_q_ptr nullptr, // seqlen_k_ptr + nullptr, // cu_seqlen_q_ptr + nullptr, // cu_seqlen_k_ptr seqlen_q, seqlen_k, b, diff --git a/csrc/flash_attn_ck/mha_fwd.cpp b/csrc/flash_attn_ck/mha_fwd.cpp index 4d7d5bd655e..c2156f951d1 100644 --- a/csrc/flash_attn_ck/mha_fwd.cpp +++ b/csrc/flash_attn_ck/mha_fwd.cpp @@ -24,7 +24,7 @@ fmha_fwd_traits get_ck_fmha_fwd_traits(const mask_info &mask, enable_alibi ? bias_enum::alibi : bias_enum::no_bias, has_lse, has_dropout, - false}; // do_fp8_static_quant + quant_scale_enum::no_scale}; // qscale_type } fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse, @@ -95,12 +95,18 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse, k.data_ptr(), v.data_ptr(), alibi_slopes_ptr, // bias + nullptr, // q_descale_ptr + nullptr, // k_descale_ptr + nullptr, // v_descale_ptr has_dropout_randval ? dropout_randval.data_ptr() : nullptr, has_lse ? softmax_lse.data_ptr() : nullptr, out.data_ptr(), - nullptr, // seqstart_q - nullptr, // seqstart_k - nullptr, + nullptr, // seqstart_q_ptr + nullptr, // seqstart_k_ptr + nullptr, // seqlen_q_ptr + nullptr, // seqlen_k_ptr + nullptr, // cu_seqlen_q_ptr + nullptr, // cu_seqlen_k_ptr seqlen_q, seqlen_k, b, @@ -110,8 +116,6 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse, h, // nhead h_k, // nhead_k softmax_scale, // scale_s - 1, // scale_p - 1, // scale_o 0.0f, // logits_soft_cap stride_q, stride_k, diff --git a/csrc/flash_attn_ck/mha_varlen_bwd.cpp b/csrc/flash_attn_ck/mha_varlen_bwd.cpp index bfeb3b770d0..14e04af42f4 100644 --- a/csrc/flash_attn_ck/mha_varlen_bwd.cpp +++ b/csrc/flash_attn_ck/mha_varlen_bwd.cpp @@ -141,7 +141,10 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask, dq_acc.data_ptr(), // dq_acc seqlens_q.data_ptr(), // seqstart_q seqlens_k.data_ptr(), // seqstart_k + nullptr, // seqlen_q_ptr nullptr, // seqlen_k_ptr + nullptr, // cu_seqlen_q_ptr + nullptr, // cu_seqlen_k_ptr total_q, total_k, b, diff --git a/csrc/flash_attn_ck/mha_varlen_fwd.cpp b/csrc/flash_attn_ck/mha_varlen_fwd.cpp index 07cfa9a8f90..946b713ef94 100644 --- a/csrc/flash_attn_ck/mha_varlen_fwd.cpp +++ b/csrc/flash_attn_ck/mha_varlen_fwd.cpp @@ -24,7 +24,7 @@ fmha_fwd_traits get_ck_fmha_varlen_fwd_traits(const mask_info &mask, enable_alibi ? bias_enum::alibi : bias_enum::no_bias, has_lse, has_dropout, - false}; // do_fp8_static_quant + quant_scale_enum::no_scale}; // qscale_type } fmha_fwd_splitkv_traits get_ck_fmha_varlen_fwd_splitkv_traits(const mask_info &mask, @@ -116,12 +116,18 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse, k.data_ptr(), v.data_ptr(), alibi_slopes_ptr, // bias + nullptr, // q_descale_ptr + nullptr, // k_descale_ptr + nullptr, // v_descale_ptr has_dropout_randval ? dropout_randval.data_ptr() : nullptr, has_lse ? softmax_lse.data_ptr() : nullptr, out.data_ptr(), - seqlens_q.data_ptr(), // seqstart_q - seqlens_k.data_ptr(), // seqstart_k - nullptr, // seqlen_kpads + seqlens_q.data_ptr(), // seqstart_q_ptr + seqlens_k.data_ptr(), // seqstart_k_ptr + nullptr, // seqlen_q_ptr + nullptr, // seqlen_k_ptr + nullptr, // cu_seqlen_q_ptr + nullptr, // cu_seqlen_k_ptr total_q, total_k, b, @@ -131,8 +137,6 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse, h, // nhead h_k, // nhead_k softmax_scale, // scale_s - 1, // scale_p - 1, // scale_o 0.0f, // logits_soft_cap stride_q, stride_k,