Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
[submodule "csrc/cutlass"]
path = csrc/cutlass
url = https://github.com/NVIDIA/cutlass.git
[submodule "csrc/composable_kernel"]
path = csrc/composable_kernel
url = https://github.com/ROCm/composable_kernel.git
branch = amd-master
[submodule "third_party/aiter"]
path = third_party/aiter
url = https://github.com/ROCm/aiter.git
1 change: 0 additions & 1 deletion csrc/composable_kernel
Submodule composable_kernel deleted from 13f6d6
5 changes: 5 additions & 0 deletions csrc/flash_attn_ck/ck_build_config.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
// Generated by setup.py. Do not edit manually.
#pragma once

#define FLASHATTN_CK_GIT_COMMIT "unknown"
#define FLASHATTN_CK_USE_CURRENT_API 0
185 changes: 149 additions & 36 deletions csrc/flash_attn_ck/mha_bwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@

#include "flash_common.hpp"

#include "ck_build_config.hpp"
#include "fmha_bwd.hpp"
#include "mask.hpp"
#include <variant>

fmha_bwd_traits get_ck_fmha_bwd_traits(const mask_info &mask,
std::string dtype,
Expand All @@ -14,16 +16,38 @@ fmha_bwd_traits get_ck_fmha_bwd_traits(const mask_info &mask,
bool enable_alibi,
bool deterministic)
{
#if FLASHATTN_CK_USE_CURRENT_API
return fmha_bwd_traits{
.seqlen_q = -1,
.seqlen_k = -1,
.batch = -1,
.max_seqlen_q = -1,
.max_seqlen_k = -1,
.hdim_q = head_size,
.hdim_v = head_size,
.nhead_q = -1,
.nhead_k = -1,
.data_type = std::move(dtype),
.is_group_mode = false,
.mask_type = mask.type,
.bias_type = enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
.has_dbias = false,
.has_dropout = has_dropout,
.is_store_randval = false,
.is_deterministic = deterministic,
};
#else
return fmha_bwd_traits{head_size,
head_size,
dtype,
false, // is_group_mode
false,
mask.type,
enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
false, // has_dbias
false,
has_dropout,
false, // s_randval
false,
deterministic};
#endif
}

fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
Expand Down Expand Up @@ -98,11 +122,11 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
ck_tile::index_t stride_dv = dv.stride(1);
ck_tile::index_t nhead_stride_dv = dv.stride(2);

// dq_acc: (split, batch_size, seqlen_q, nheads, hdim)
ck_tile::index_t split_stride_dq_acc = dq_acc.stride(0);
ck_tile::index_t batch_stride_dq_acc = dq_acc.stride(1);
ck_tile::index_t stride_dq_acc = dq_acc.stride(2);
ck_tile::index_t nhead_stride_dq_acc = dq_acc.stride(3);
// dq_acc: (batch_size, nheads, split, seqlen_q, hdim)
ck_tile::index_t batch_stride_dq_acc = dq_acc.stride(0);
ck_tile::index_t nhead_stride_dq_acc = dq_acc.stride(1);
ck_tile::index_t split_stride_dq_acc = dq_acc.stride(2);
ck_tile::index_t stride_dq_acc = dq_acc.stride(3);

float p_undrop = 1.0 - p_dropout;

Expand All @@ -119,81 +143,167 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
}

#if FLASHATTN_CK_USE_CURRENT_API
auto drop_seed_var = std::variant<std::pair<uint64_t, uint64_t>,
std::pair<const void*, const void*>>{
std::pair<uint64_t, uint64_t>{0, 0}};
if (drop_seed_offset.first != nullptr && drop_seed_offset.second != nullptr) {
drop_seed_var = std::pair<const void*, const void*>{drop_seed_offset.first, drop_seed_offset.second};
}
return fmha_bwd_args{
.q_ptr = q.data_ptr(),
.k_ptr = k.data_ptr(),
.v_ptr = v.data_ptr(),
.bias_ptr = alibi_slopes_ptr,
.o_ptr = out.data_ptr(),
.lse_ptr = softmax_lse.data_ptr(),
.do_ptr = dout.data_ptr(),
.d_ptr = d.data_ptr(),
.rand_val_ptr = nullptr,
.dq_ptr = dq.data_ptr(),
.dk_ptr = dk.data_ptr(),
.dv_ptr = dv.data_ptr(),
.dbias_ptr = nullptr,
.dq_acc_ptr = dq_acc.data_ptr(),
.seqstart_q_ptr = nullptr,
.seqstart_k_ptr = nullptr,
.seqlen_q_ptr = nullptr,
.seqlen_k_ptr = nullptr,
.cu_seqlen_q_ptr = nullptr,
.cu_seqlen_k_ptr = nullptr,
.seqlen_q = seqlen_q,
.seqlen_k = seqlen_k,
.batch = b,
.max_seqlen_q = seqlen_q,
.max_seqlen_k = seqlen_k,
.hdim_q = hdim,
.hdim_v = hdim,
.nhead_q = h,
.nhead_k = h_k,
.scale = softmax_scale,
.stride_q = stride_q,
.stride_k = stride_k,
.stride_v = stride_v,
.stride_bias = stride_alibi_slopes,
.stride_o = stride_o,
.stride_randval = 0,
.stride_do = stride_do,
.stride_dq_acc = stride_dq_acc,
.stride_dq = stride_dq,
.stride_dk = stride_dk,
.stride_dv = stride_dv,
.stride_dbias = 0,
.nhead_stride_q = nhead_stride_q,
.nhead_stride_k = nhead_stride_k,
.nhead_stride_v = nhead_stride_v,
.nhead_stride_bias = 0,
.nhead_stride_o = nhead_stride_o,
.nhead_stride_randval = 0,
.nhead_stride_do = nhead_stride_do,
.nhead_stride_lsed = nhead_stride_lse,
.nhead_stride_dq_acc = nhead_stride_dq_acc,
.nhead_stride_dq = nhead_stride_dq,
.nhead_stride_dk = nhead_stride_dk,
.nhead_stride_dv = nhead_stride_dv,
.nhead_stride_dbias = 0,
.batch_stride_q = batch_stride_q,
.batch_stride_k = batch_stride_k,
.batch_stride_v = batch_stride_v,
.batch_stride_bias = 0,
.batch_stride_o = batch_stride_o,
.batch_stride_randval = 0,
.batch_stride_do = batch_stride_do,
.batch_stride_lsed = batch_stride_lse,
.batch_stride_dq_acc = batch_stride_dq_acc,
.batch_stride_dq = batch_stride_dq,
.batch_stride_dk = batch_stride_dk,
.batch_stride_dv = batch_stride_dv,
.batch_stride_dbias = 0,
.split_stride_dq_acc = split_stride_dq_acc,
.window_size_left = mask.left,
.window_size_right = mask.right,
.mask_type = static_cast<ck_tile::index_t>(mask.type),
.p_drop = p_dropout,
.p_undrop = p_undrop,
.drop_seed_offset = std::move(drop_seed_var),
};
#else
return fmha_bwd_args{q.data_ptr(),
k.data_ptr(),
v.data_ptr(),
alibi_slopes_ptr, // bias
alibi_slopes_ptr,
out.data_ptr(),
softmax_lse.data_ptr(),
dout.data_ptr(),
d.data_ptr(),
nullptr, // rand_val
nullptr,
dq.data_ptr(),
dk.data_ptr(),
dv.data_ptr(),
nullptr, // dbias
dq_acc.data_ptr(), // dq_acc
nullptr, // seqstart_q_ptr
nullptr, // seqstart_k_ptr
nullptr, // seqlen_q_ptr
nullptr, // seqlen_k_ptr
nullptr, // cu_seqlen_q_ptr
nullptr, // cu_seqlen_k_ptr
nullptr,
dq_acc.data_ptr(),
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
seqlen_q,
seqlen_k,
b,
seqlen_q, // max_seqlen_q
seqlen_k, // max_seqlen_k
hdim, // hdim_q
hdim, // hdim_v
h, // nhead
h_k, // nhead_k
seqlen_q,
seqlen_k,
hdim,
hdim,
h,
h_k,
softmax_scale,
stride_q,
stride_k,
stride_v,
stride_alibi_slopes,
stride_o,
0, // stride_randval
0,
stride_do,
stride_dq_acc,
stride_dq,
stride_dk,
stride_dv,
0, // stride_dbias, FA without bias
0,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
0, // nhead_stride_bias, FA without bias
0,
nhead_stride_o,
0, // nhead_stride_randval
0,
nhead_stride_do,
nhead_stride_lse,
nhead_stride_dq_acc,
nhead_stride_dq,
nhead_stride_dk,
nhead_stride_dv,
0, // nhead_stride_dbias, FA without dbias
0,
batch_stride_q,
batch_stride_k,
batch_stride_v,
0 , // batch_stride_bias, FA without bias
0,
batch_stride_o,
0, // batch_stride_randval
0,
batch_stride_do,
batch_stride_lse,
batch_stride_dq_acc,
batch_stride_dq,
batch_stride_dk,
batch_stride_dv,
0 , // batch_stride_dbias, FA without dbias
0,
split_stride_dq_acc,
mask.left,
mask.right,
static_cast<ck_tile::index_t>(mask.type),
p_dropout,
p_undrop,
drop_seed_offset};
#endif
}

std::vector<at::Tensor>
Expand Down Expand Up @@ -323,11 +433,11 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
at::Tensor dq_accum;

if (!deterministic) {
dq_accum = torch::zeros({1, batch_size, seqlen_q, num_heads, head_size}, opts.dtype(at::kFloat));
dq_accum = torch::zeros({batch_size, num_heads, 1, seqlen_q, head_size}, opts.dtype(at::kFloat));
} else {
const ck_tile::index_t kN0 = head_size <= 128 ? 128 : 64;
const ck_tile::index_t nsplits = ck_tile::integer_divide_ceil(seqlen_k, kN0);
dq_accum = torch::zeros({nsplits, batch_size, seqlen_q, num_heads, head_size}, opts.dtype(at::kFloat));
dq_accum = torch::zeros({batch_size, num_heads, nsplits, seqlen_q, head_size}, opts.dtype(at::kFloat));
}

at::Tensor dk_expanded, dv_expanded;
Expand Down Expand Up @@ -358,8 +468,11 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
}

if (seqlen_q > 0) {
auto rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
std::pair<uint64_t*, uint64_t*> drop_seed_offset = {nullptr, nullptr};
if (is_dropout) {
auto rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
}
ck_tile::stream_config stream_config{stream};

auto traits =
Expand Down
Loading