Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion csrc/composable_kernel
3 changes: 3 additions & 0 deletions csrc/flash_attn_ck/mha_bwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,10 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
dq_acc.data_ptr(), // dq_acc
nullptr, // seqstart_q
nullptr, // seqstart_k
nullptr, // seqlen_q_ptr
nullptr, // seqlen_k_ptr
nullptr, // cu_seqlen_q_ptr
nullptr, // cu_seqlen_k_ptr
seqlen_q,
seqlen_k,
b,
Expand Down
16 changes: 10 additions & 6 deletions csrc/flash_attn_ck/mha_fwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ fmha_fwd_traits get_ck_fmha_fwd_traits(const mask_info &mask,
enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
has_lse,
has_dropout,
false}; // do_fp8_static_quant
quant_scale_enum::no_scale}; // qscale_type
}

fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
Expand Down Expand Up @@ -95,12 +95,18 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
k.data_ptr(),
v.data_ptr(),
alibi_slopes_ptr, // bias
nullptr, // q_descale_ptr
nullptr, // k_descale_ptr
nullptr, // v_descale_ptr
has_dropout_randval ? dropout_randval.data_ptr() : nullptr,
has_lse ? softmax_lse.data_ptr() : nullptr,
out.data_ptr(),
nullptr, // seqstart_q
nullptr, // seqstart_k
nullptr,
nullptr, // seqstart_q_ptr
nullptr, // seqstart_k_ptr
nullptr, // seqlen_q_ptr
nullptr, // seqlen_k_ptr
nullptr, // cu_seqlen_q_ptr
nullptr, // cu_seqlen_k_ptr
seqlen_q,
seqlen_k,
b,
Expand All @@ -110,8 +116,6 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
h, // nhead
h_k, // nhead_k
softmax_scale, // scale_s
1, // scale_p
1, // scale_o
0.0f, // logits_soft_cap
stride_q,
stride_k,
Expand Down
3 changes: 3 additions & 0 deletions csrc/flash_attn_ck/mha_varlen_bwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,10 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
dq_acc.data_ptr(), // dq_acc
seqlens_q.data_ptr(), // seqstart_q
seqlens_k.data_ptr(), // seqstart_k
nullptr, // seqlen_q_ptr
nullptr, // seqlen_k_ptr
nullptr, // cu_seqlen_q_ptr
nullptr, // cu_seqlen_k_ptr
total_q,
total_k,
b,
Expand Down
16 changes: 10 additions & 6 deletions csrc/flash_attn_ck/mha_varlen_fwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ fmha_fwd_traits get_ck_fmha_varlen_fwd_traits(const mask_info &mask,
enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
has_lse,
has_dropout,
false}; // do_fp8_static_quant
quant_scale_enum::no_scale}; // qscale_type
}

fmha_fwd_splitkv_traits get_ck_fmha_varlen_fwd_splitkv_traits(const mask_info &mask,
Expand Down Expand Up @@ -116,12 +116,18 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
k.data_ptr(),
v.data_ptr(),
alibi_slopes_ptr, // bias
nullptr, // q_descale_ptr
nullptr, // k_descale_ptr
nullptr, // v_descale_ptr
has_dropout_randval ? dropout_randval.data_ptr() : nullptr,
has_lse ? softmax_lse.data_ptr() : nullptr,
out.data_ptr(),
seqlens_q.data_ptr(), // seqstart_q
seqlens_k.data_ptr(), // seqstart_k
nullptr, // seqlen_kpads
seqlens_q.data_ptr(), // seqstart_q_ptr
seqlens_k.data_ptr(), // seqstart_k_ptr
nullptr, // seqlen_q_ptr
nullptr, // seqlen_k_ptr
nullptr, // cu_seqlen_q_ptr
nullptr, // cu_seqlen_k_ptr
total_q,
total_k,
b,
Expand All @@ -131,8 +137,6 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
h, // nhead
h_k, // nhead_k
softmax_scale, // scale_s
1, // scale_p
1, // scale_o
0.0f, // logits_soft_cap
stride_q,
stride_k,
Expand Down