Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions aiter/ops/mha.py
Original file line number Diff line number Diff line change
Expand Up @@ -1500,10 +1500,11 @@ def can_impl_fmha_v3_bwd_gfx950():
ret &= dbias is None
ret &= dropout_p == 0.0
ret &= not deterministic or is_950_1block
ret &= hdim_q == hdim_v
ret &= nhead_q % nhead_k == 0
ret &= hdim_q > 64 and hdim_q <= 128 and hdim_q % 8 == 0

ret &= (
(hdim_q > 64 and hdim_q <= 128)
or (hdim_q == 192 and hdim_v == 128 and nmask)
) and hdim_q % 8 == 0
return ret

can_impl_fmha_v3_bwd_ |= can_impl_fmha_v3_bwd_gfx950()
Expand Down
6 changes: 4 additions & 2 deletions csrc/include/mha_bwd.h
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,8 @@ struct fmha_bwd_v3_traits
int ts_dq = 64;
};

template <ck_tile::index_t HDim_,
template <ck_tile::index_t HDim_q_,
ck_tile::index_t HDim_v_,
typename DataType_,
int mask_type_,
bool kIsAtomic32_,
Expand All @@ -397,7 +398,8 @@ template <ck_tile::index_t HDim_,
GPUArch GPUArch_>
struct fmha_bwd_dq_dk_dv_v3_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
static constexpr ck_tile::index_t HDim_q = HDim_q_;
static constexpr ck_tile::index_t HDim_v = HDim_v_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr int mask_type = mask_type_;
static constexpr bool kIsAtomic32 = kIsAtomic32_;
Expand Down
8 changes: 4 additions & 4 deletions csrc/py_itfs_ck/mha_bwd_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -329,14 +329,14 @@ mha_bwd(const at::Tensor &dout, // [b, sq, hq, d_v]
at::Tensor dq_accum;

if (!deterministic) {
dq_accum = torch::zeros({1, batch_size, seqlen_q, num_heads, head_size_v}, opts.dtype(at::kFloat));
dq_accum = torch::zeros({1, batch_size, seqlen_q, num_heads, head_size_q}, opts.dtype(at::kFloat));
} else {
const ck_tile::index_t kN0 = head_size_v <= 128 ? 128 : 64;
const ck_tile::index_t kN0 = head_size_q <= 128 ? 128 : 64;
const ck_tile::index_t nsplits = ck_tile::integer_divide_ceil(seqlen_k, kN0);
if (mask.type == mask_enum::no_mask)
dq_accum = torch::empty({nsplits, batch_size, seqlen_q, num_heads, head_size_v}, opts.dtype(at::kFloat));
dq_accum = torch::empty({nsplits, batch_size, seqlen_q, num_heads, head_size_q}, opts.dtype(at::kFloat));
else // Some block may be skipped with causal mask and dq are not set to zeros
dq_accum = torch::zeros({nsplits, batch_size, seqlen_q, num_heads, head_size_v}, opts.dtype(at::kFloat));
dq_accum = torch::zeros({nsplits, batch_size, seqlen_q, num_heads, head_size_q}, opts.dtype(at::kFloat));
}

at::Tensor dk_expanded, dv_expanded;
Expand Down
4 changes: 2 additions & 2 deletions csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -318,11 +318,11 @@ mha_varlen_bwd(const at::Tensor &dout, // [total_q, hq, d_v]
at::Tensor dq_accum;

if (!deterministic) {
dq_accum = torch::zeros({1, total_q, num_heads, head_size_v}, opts.dtype(at::kFloat));
dq_accum = torch::zeros({1, total_q, num_heads, head_size_q}, opts.dtype(at::kFloat));
} else {
const ck_tile::index_t kN0 = head_size_q <= 128 ? 128 : 64;
const ck_tile::index_t nsplits = ck_tile::integer_divide_ceil(max_seqlen_k, kN0);
dq_accum = torch::zeros({nsplits, total_q, num_heads, head_size_v}, opts.dtype(at::kFloat));
dq_accum = torch::zeros({nsplits, total_q, num_heads, head_size_q}, opts.dtype(at::kFloat));
}

at::Tensor dk_expanded, dv_expanded;
Expand Down
11 changes: 4 additions & 7 deletions csrc/py_itfs_cu/asm_mha_bwd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -303,16 +303,13 @@ std::vector<at::Tensor> fmha_v3_bwd(const at::Tensor &dout, // [b, sq, h

if (!deterministic) {
if (is_v3_atomic_fp32) {
dq_accum = torch::zeros({1, batch_size, num_heads, seqlen_q, head_size_v}, opts.dtype(at::kFloat));
dq_accum = torch::zeros({1, batch_size, num_heads, seqlen_q, head_size_q}, opts.dtype(at::kFloat));
} else {
// When atomic16, padding dq_accum seqlen to 16x, head dim to 128
// When atomic16, padding dq_accum seqlen to 16x, head dim to 128/192
// In this case, dq_accum could have any layout, we set it to be `bhsd`
dq_accum = torch::zeros({1, batch_size, num_heads, (seqlen_q + 15) / 16 * 16, 128}, opts.dtype(q_dtype));
int padded_head_size_q = head_size_q == 192? 192: 128;
dq_accum = torch::zeros({1, batch_size, num_heads, (seqlen_q + 15) / 16 * 16, padded_head_size_q}, opts.dtype(q_dtype));
}
} else {
const ck_tile::index_t kN0 = head_size_v <= 128 ? 128 : 64;
const ck_tile::index_t nsplits = ck_tile::integer_divide_ceil(seqlen_k, kN0);
dq_accum = torch::zeros({nsplits, batch_size, num_heads, seqlen_q, head_size_v}, opts.dtype(at::kFloat));
}

at::Tensor dk_expanded, dv_expanded;
Expand Down
8 changes: 2 additions & 6 deletions csrc/py_itfs_cu/asm_mha_varlen_bwd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ fmha_bwd_args get_asm_fmha_varlen_bwd_args(const mask_info &mask,
ck_tile::index_t batch_stride_dq_acc;
ck_tile::index_t nhead_stride_dq_acc;
ck_tile::index_t stride_dq_acc;
// For atomic32, dq_acc layout is (1, num_heads, total_q, head_size_v)
// For atomic32, dq_acc layout is (1, num_heads, total_q, head_size_q)
// For atomic16, dq_acc layout is (1, batch_size, num_heads, (max_seqlen_q + 15) / 16 * 16, 128)
if (is_v3_atomic_fp32) {
split_stride_dq_acc = dq_acc.stride(0);
Expand Down Expand Up @@ -338,16 +338,12 @@ fmha_v3_varlen_bwd(const at::Tensor &dout, // [total_q, hq, d_v

if (!deterministic) {
if (is_v3_atomic_fp32) {
dq_accum = torch::zeros({1, num_heads, total_q, head_size_v}, opts.dtype(at::kFloat));
dq_accum = torch::zeros({1, num_heads, total_q, head_size_q}, opts.dtype(at::kFloat));
} else {
// When atomic16, padding dq_accum seqlen to 16x of max_seqlen_q, head dim to 128
// In this case, dq_accum could have any layout, we set it to be `bhsd`
dq_accum = torch::zeros({1, batch_size, num_heads, (max_seqlen_q + 15) / 16 * 16, 128}, opts.dtype(q_dtype));
}
} else {
const ck_tile::index_t kN0 = head_size_q <= 128 ? 128 : 64;
const ck_tile::index_t nsplits = ck_tile::integer_divide_ceil(max_seqlen_k, kN0);
dq_accum = torch::zeros({nsplits, num_heads, total_q, head_size_v}, opts.dtype(at::kFloat));
}

at::Tensor dk_expanded, dv_expanded;
Expand Down
Loading