Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
247 changes: 120 additions & 127 deletions csrc/flash_attn_ck/mha_bwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,52 +4,37 @@

#include <utility>

#include "flash_common.hpp"

#include "fmha_bwd.hpp"
#include "mha_bwd.h"
#include "mask.hpp"

/// Build the CK FMHA backward-pass dispatch traits.
///
/// @param mask          Attention mask descriptor; only its `type` field is forwarded.
/// @param dtype         Element-type name string, passed through to the traits.
///                      Taken by value as a sink parameter and consumed (moved).
/// @param head_size     Head dimension; used for both the query and value head dims.
/// @param has_dropout   Whether attention dropout is enabled.
/// @param enable_alibi  Whether ALiBi positional bias is in use (selects bias_enum::alibi).
/// @param deterministic Whether the deterministic backward kernel is required.
/// @return              A populated fmha_bwd_traits aggregate for kernel dispatch.
fmha_bwd_traits get_ck_fmha_bwd_traits(const mask_info &mask,
                                       std::string dtype,
                                       int head_size,
                                       bool has_dropout,
                                       bool enable_alibi,
                                       bool deterministic)
{
    return fmha_bwd_traits{head_size,        // hdim_q
                           head_size,        // hdim_v — assumes hdim_q == hdim_v here
                           std::move(dtype), // sink param: move instead of copying the string again
                           false,            // is_group_mode
                           mask.type,
                           enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
                           false,            // has_dbias
                           has_dropout,
                           false,            // s_randval
                           deterministic};
}

fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
// sizes
const int b,
const int seqlen_q,
const int seqlen_k,
const int h,
const int h_k,
const int hdim,
// device pointers
const at::Tensor q,
const at::Tensor k,
const at::Tensor v,
std::optional<at::Tensor> &alibi_slopes_,
const at::Tensor out,
const at::Tensor softmax_lse,
const at::Tensor dout,
at::Tensor dq_acc,
at::Tensor d,
at::Tensor dq,
at::Tensor dk,
at::Tensor dv,
float softmax_scale,
float p_dropout,
std::pair<uint64_t*, uint64_t*> drop_seed_offset)
aiter::mha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
std::string dtype,
bool enable_alibi,
bool has_dropout,
bool deterministic,
// sizes
const int b,
const int seqlen_q,
const int seqlen_k,
const int h,
const int h_k,
const int hdim,
// device pointers
const at::Tensor q,
const at::Tensor k,
const at::Tensor v,
std::optional<at::Tensor> &alibi_slopes_,
const at::Tensor out,
const at::Tensor softmax_lse,
const at::Tensor dout,
at::Tensor dq_acc,
at::Tensor d,
at::Tensor dq,
at::Tensor dk,
at::Tensor dv,
float softmax_scale,
float p_dropout,
std::pair<uint64_t*, uint64_t*> drop_seed_offset)
{
// q: (batch_size, seqlen_q, nheads, hdim)
ck_tile::index_t batch_stride_q = q.stride(0);
Expand Down Expand Up @@ -80,9 +65,6 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
ck_tile::index_t stride_do = dout.stride(1);
ck_tile::index_t nhead_stride_do = dout.stride(2);

// d: (batch_size, nheads, seqlen_q)
// CK assume d share the same stride with lse

// dq: (batch_size, seqlen_q, nheads, hdim)
ck_tile::index_t batch_stride_dq = dq.stride(0);
ck_tile::index_t stride_dq = dq.stride(1);
Expand Down Expand Up @@ -115,85 +97,95 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h}));
alibi_slopes_ptr = alibi_slopes.data_ptr();
// alibi_slopes:(batch_size, nheads) or (nhead)
stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
}

return fmha_bwd_args{q.data_ptr(),
k.data_ptr(),
v.data_ptr(),
alibi_slopes_ptr, // bias
out.data_ptr(),
softmax_lse.data_ptr(),
dout.data_ptr(),
d.data_ptr(),
nullptr, // rand_val
dq.data_ptr(),
dk.data_ptr(),
dv.data_ptr(),
nullptr, // dbias
dq_acc.data_ptr(), // dq_acc
nullptr, // seqstart_q_ptr
nullptr, // seqstart_k_ptr
nullptr, // seqlen_q_ptr
nullptr, // seqlen_k_ptr
nullptr, // cu_seqlen_q_ptr
nullptr, // cu_seqlen_k_ptr
seqlen_q,
seqlen_k,
b,
seqlen_q, // max_seqlen_q
seqlen_k, // max_seqlen_k
hdim, // hdim_q
hdim, // hdim_v
h, // nhead
h_k, // nhead_k
softmax_scale,
stride_q,
stride_k,
stride_v,
stride_alibi_slopes,
stride_o,
0, // stride_randval
stride_do,
stride_dq_acc,
stride_dq,
stride_dk,
stride_dv,
0, // stride_dbias, FA without bias
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
0, // nhead_stride_bias, FA without bias
nhead_stride_o,
0, // nhead_stride_randval
nhead_stride_do,
nhead_stride_lse,
nhead_stride_dq_acc,
nhead_stride_dq,
nhead_stride_dk,
nhead_stride_dv,
0, // nhead_stride_dbias, FA without dbias
batch_stride_q,
batch_stride_k,
batch_stride_v,
0 , // batch_stride_bias, FA without bias
batch_stride_o,
0, // batch_stride_randval
batch_stride_do,
batch_stride_lse,
batch_stride_dq_acc,
batch_stride_dq,
batch_stride_dk,
batch_stride_dv,
0 , // batch_stride_dbias, FA without dbias
split_stride_dq_acc,
mask.left,
mask.right,
static_cast<ck_tile::index_t>(mask.type),
p_dropout,
p_undrop,
drop_seed_offset};
return aiter::mha_bwd_args{false, // use_asm_v3
false, // v3_atomic_fp32
1, // v3_bf16_cvt
false, // v3_api_check
hdim, // hdim_q
hdim, // hdim_v
dtype,
false, // is_group_mode
static_cast<int>(mask.type),
static_cast<int>(enable_alibi ? bias_enum::alibi : bias_enum::no_bias),
false, // has_dbias
has_dropout,
false, // is_store_randval
deterministic,
q.data_ptr(),
k.data_ptr(),
v.data_ptr(),
alibi_slopes_ptr, // bias
out.data_ptr(),
softmax_lse.data_ptr(),
dout.data_ptr(),
d.data_ptr(),
nullptr, // rand_val
dq.data_ptr(),
dk.data_ptr(),
dv.data_ptr(),
nullptr, // dbias
dq_acc.data_ptr(),
nullptr, // seqstart_q_ptr
nullptr, // seqstart_k_ptr
nullptr, // seqlen_q_ptr
nullptr, // seqlen_k_ptr
nullptr, // cu_seqlen_q_ptr
nullptr, // cu_seqlen_k_ptr
seqlen_q,
seqlen_k,
b,
seqlen_q, // max_seqlen_q
seqlen_k, // max_seqlen_k
h, // nhead_q
h_k, // nhead_k
softmax_scale,
stride_q,
stride_k,
stride_v,
stride_alibi_slopes,
stride_o,
0, // stride_randval
stride_do,
stride_dq_acc,
stride_dq,
stride_dk,
stride_dv,
0, // stride_dbias
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
0, // nhead_stride_bias
nhead_stride_o,
0, // nhead_stride_randval
nhead_stride_do,
nhead_stride_lse,
static_cast<int64_t>(nhead_stride_dq_acc),
nhead_stride_dq,
nhead_stride_dk,
nhead_stride_dv,
0, // nhead_stride_dbias
batch_stride_q,
batch_stride_k,
batch_stride_v,
0, // batch_stride_bias
batch_stride_o,
0, // batch_stride_randval
batch_stride_do,
batch_stride_lse,
static_cast<int64_t>(batch_stride_dq_acc),
batch_stride_dq,
batch_stride_dk,
batch_stride_dv,
0, // batch_stride_dbias
split_stride_dq_acc,
mask.left,
mask.right,
p_dropout,
p_undrop,
drop_seed_offset};
}

std::vector<at::Tensor>
Expand Down Expand Up @@ -362,12 +354,13 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
ck_tile::stream_config stream_config{stream};

auto traits =
get_ck_fmha_bwd_traits(mask, q_dtype_str, head_size, is_dropout, alibi_slopes_.has_value(), deterministic);

auto args =
get_ck_fmha_bwd_args(
mask,
q_dtype_str,
alibi_slopes_.has_value(),
is_dropout,
deterministic,
batch_size,
seqlen_q,
seqlen_k,
Expand All @@ -390,7 +383,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
p_dropout,
drop_seed_offset);

float t = fmha_bwd(traits, args, stream_config);
float t = aiter::mha_bwd(args, stream_config);
TORCH_CHECK(t >= 0, "invalid argument for fmha_bwd");
} else {
// If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
Expand Down
Loading