2 changes: 1 addition & 1 deletion csrc/composable_kernel
58 changes: 37 additions & 21 deletions csrc/flash_attn_ck/mha_bwd.cpp
@@ -9,13 +9,25 @@

fmha_bwd_traits get_ck_fmha_bwd_traits(const mask_info &mask,
std::string dtype,
int seqlen_q,
int seqlen_k,
int batch,
int head_size,
int nhead_q,
int nhead_k,
bool has_dropout,
bool enable_alibi,
bool deterministic)
{
return fmha_bwd_traits{head_size,
head_size,
return fmha_bwd_traits{seqlen_q,
seqlen_k,
batch,
seqlen_q, // max_seqlen_q
seqlen_k, // max_seqlen_k
head_size, // hdim_q
head_size, // hdim_k
nhead_q,
nhead_k,
dtype,
false, // is_group_mode
mask.type,
@@ -98,11 +110,11 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
ck_tile::index_t stride_dv = dv.stride(1);
ck_tile::index_t nhead_stride_dv = dv.stride(2);

// dq_acc: (split, batch_size, seqlen_q, nheads, hdim)
ck_tile::index_t split_stride_dq_acc = dq_acc.stride(0);
ck_tile::index_t batch_stride_dq_acc = dq_acc.stride(1);
ck_tile::index_t stride_dq_acc = dq_acc.stride(2);
ck_tile::index_t nhead_stride_dq_acc = dq_acc.stride(3);
// dq_acc: (batch_size, nheads, split, seqlen_q, hdim)
ck_tile::long_index_t batch_stride_dq_acc = dq_acc.stride(0);
ck_tile::long_index_t nhead_stride_dq_acc = dq_acc.stride(1);
ck_tile::index_t split_stride_dq_acc = dq_acc.stride(2);
ck_tile::index_t stride_dq_acc = dq_acc.stride(3);

float p_undrop = 1.0 - p_dropout;

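Note: dq_acc changes layout here, from (split, batch_size, seqlen_q, nheads, hdim) to (batch_size, nheads, split, seqlen_q, hdim), and the two outer strides are widened to ck_tile::long_index_t since they can exceed 32 bits on large problems. A minimal sketch of how contiguous strides fall out of the new allocation order, with int64_t/int32_t standing in for the ck_tile index types (illustration only, not kernel code):

#include <cstdint>

using long_index_t = int64_t; // stand-in for ck_tile::long_index_t
using index_t = int32_t;      // stand-in for ck_tile::index_t

struct DqAccStrides {
    long_index_t batch_stride; // dq_acc.stride(0)
    long_index_t nhead_stride; // dq_acc.stride(1)
    index_t split_stride;      // dq_acc.stride(2)
    index_t row_stride;        // dq_acc.stride(3), step between seqlen_q rows
};

DqAccStrides contiguous_dq_acc_strides(int64_t nheads, int64_t nsplits,
                                       int64_t seqlen_q, int64_t hdim) {
    DqAccStrides s;
    s.row_stride   = static_cast<index_t>(hdim);            // hdim is innermost
    s.split_stride = static_cast<index_t>(seqlen_q * hdim);
    s.nhead_stride = nsplits * seqlen_q * hdim;             // 64-bit products
    s.batch_stride = s.nhead_stride * nheads;
    return s;
}

With batch and head outermost, the (split, seqlen_q, hdim) slice for a fixed (batch, head) is one contiguous block, which is the region the reduction over splits reads.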
Expand Down Expand Up @@ -222,7 +234,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
#endif
if (is_causal) { window_size_right = 0; }

bool is_dropout = p_dropout > 0.0;
const bool is_dropout = p_dropout > 0.0;
#ifdef HIPIFY_V2
auto stream = at::cuda::getCurrentCUDAStream().stream();
#else
@@ -238,7 +250,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");

std::string q_dtype_str = q_dtype == torch::kFloat16 ? "fp16" : "bf16";
const std::string q_dtype_str = q_dtype == torch::kFloat16 ? "fp16" : "bf16";

CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
@@ -316,19 +328,26 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
dv = torch::empty_like(v);
}

const auto traits = get_ck_fmha_bwd_traits(
mask,
q_dtype_str,
seqlen_q,
seqlen_k,
batch_size,
head_size,
num_heads,
num_heads_k,
is_dropout,
alibi_slopes_.has_value(),
deterministic);
fmha_bwd_launcher launcher(traits);
const ck_tile::index_t nsplits = launcher.dq_acc_splits;

at::cuda::CUDAGuard device_guard{q.device()};

auto opts = q.options();
auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
at::Tensor dq_accum;

if (!deterministic) {
dq_accum = torch::zeros({1, batch_size, seqlen_q, num_heads, head_size}, opts.dtype(at::kFloat));
} else {
const ck_tile::index_t kN0 = head_size <= 128 ? 128 : 64;
const ck_tile::index_t nsplits = ck_tile::integer_divide_ceil(seqlen_k, kN0);
dq_accum = torch::zeros({nsplits, batch_size, seqlen_q, num_heads, head_size}, opts.dtype(at::kFloat));
}
at::Tensor dq_accum = torch::zeros({batch_size, num_heads, nsplits, seqlen_q, head_size}, opts.dtype(at::kFloat));

at::Tensor dk_expanded, dv_expanded;
if (num_heads_k != num_heads) { // MQA / GQA
@@ -362,9 +381,6 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
ck_tile::stream_config stream_config{stream};

auto traits =
get_ck_fmha_bwd_traits(mask, q_dtype_str, head_size, is_dropout, alibi_slopes_.has_value(), deterministic);

auto args =
get_ck_fmha_bwd_args(
mask,
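Note: the binding no longer hard-codes the split count. The deleted lines derived nsplits from a kN0 tile size in deterministic mode (one dq_acc split per K tile) and used a single split otherwise; the new code constructs an fmha_bwd_launcher from the traits and reads launcher.dq_acc_splits. A reconstruction of the removed derivation, kept only to document what the launcher now owns:

#include <cstdint>

// Reconstructed from the deleted lines above, for reference only.
int32_t old_dq_acc_nsplits(bool deterministic, int32_t seqlen_k, int32_t head_size) {
    if (!deterministic)
        return 1; // single accumulator, dq accumulated in place
    const int32_t kN0 = head_size <= 128 ? 128 : 64;
    return (seqlen_k + kN0 - 1) / kN0; // ck_tile::integer_divide_ceil(seqlen_k, kN0)
}

Since the tile shape is a kernel implementation detail, querying it through the launcher keeps the allocation correct if composable_kernel retunes its tiling.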
18 changes: 17 additions & 1 deletion csrc/flash_attn_ck/mha_fwd.cpp
@@ -107,6 +107,10 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
nullptr, // seqlen_k_ptr
nullptr, // cu_seqlen_q_ptr
nullptr, // cu_seqlen_k_ptr
nullptr, // block_scale_seqstart_q_ptr
nullptr, // block_scale_seqstart_k_ptr
nullptr, // seqstart_v_scale_ptr
nullptr, // sink_ptr
seqlen_q,
seqlen_k,
b,
@@ -123,27 +127,39 @@
stride_alibi_slopes,
stride_randval,
stride_o,
0, // stride_q_descale
0, // stride_k_descale
0, // stride_v_descale
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
0, // nhead_stride_bias, FA without bias
nhead_stride_randval,
nhead_stride_lse,
nhead_stride_o,
0, // nhead_stride_q_descale
0, // nhead_stride_k_descale
0, // nhead_stride_v_descale
batch_stride_q,
batch_stride_k,
batch_stride_v,
0, // batch_stride_bias, FA without bias
batch_stride_randval,
batch_stride_lse,
batch_stride_o,
0, // batch_stride_q_descale
0, // batch_stride_k_descale
0, // batch_stride_v_descale
mask.left,
mask.right,
0, // sink_size
static_cast<ck_tile::index_t>(mask.type),
0, // min_seqlen_q
p_dropout,
has_dropout_randval,
drop_seed_offset};
drop_seed_offset,
0, // block_scale_size_q
0}; // block_scale_size_kv
}

std::vector<at::Tensor>
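Note: every field added to fmha_fwd_args in this file is zeroed on this path: FP8 descale strides, sink pointer/size, and block-scale metadata. A hypothetical mirror of just the new fields, to make the convention explicit (null pointer plus zero size/stride means the feature is off; field names follow the comments above, the struct itself is illustrative):

struct FwdOptionalArgs {
    const void* sink_ptr = nullptr;  // attention sinks unused
    int sink_size = 0;
    int stride_q_descale = 0;        // no FP8 descale tensors
    int stride_k_descale = 0;
    int stride_v_descale = 0;
    int block_scale_size_q = 0;      // block-scale quantization disabled
    int block_scale_size_kv = 0;
};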
11 changes: 7 additions & 4 deletions csrc/flash_attn_ck/mha_fwd_kvcache.cpp
@@ -32,13 +32,14 @@ fmha_fwd_splitkv_traits get_ck_fmha_fwd_splitkv_traits(const mask_info &mask,
return fmha_fwd_splitkv_traits{head_size,
head_size,
dtype,
false, // is_group_mode
true, // is_v_rowmajor
false, // has_logits_soft_cap
false, // is_group_mode
true, // is_v_rowmajor
false, // has_logits_soft_cap
mask.type,
enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
has_lse,
false}; // do_fp8_static_quant
false, // do_fp8_static_quant
false}; // has_sink
}

fmha_fwd_appendkv_args get_ck_fmha_fwd_appendkv_args(const int b,
@@ -177,6 +178,7 @@ fmha_fwd_splitkv_args get_ck_fmha_fwd_splitkv_args(bool has_lse,
args.o_acc_ptr = out_acc.data_ptr();
args.lse_ptr = nullptr;
args.o_ptr = out.data_ptr();
args.sink_ptr = nullptr;

if (block_table_.has_value())
{
@@ -261,6 +263,7 @@ fmha_fwd_splitkv_args get_ck_fmha_fwd_splitkv_args(bool has_lse,

args.window_size_left = mask.left;
args.window_size_right = mask.right;
args.sink_size = 0;
args.mask_type = static_cast<ck_tile::index_t>(mask.type);

return args;
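Note: the split-KV traits gain a trailing has_sink flag and the args gain sink_ptr/sink_size, all disabled here. For orientation, a conceptual sketch of attention-sink semantics (an assumption about what has_sink enables, not taken from this diff): sink logits join the softmax normalization but produce no output.

#include <algorithm>
#include <cmath>
#include <vector>

// Conceptual only: a sink adds logits to the denominator, shrinking every
// attention weight; the sink's own probability mass is discarded.
float attn_denominator(const std::vector<float>& logits, float sink_logit, bool has_sink) {
    float m = has_sink ? sink_logit : -INFINITY;
    for (float l : logits) m = std::max(m, l);            // stabilize
    float d = has_sink ? std::exp(sink_logit - m) : 0.0f;
    for (float l : logits) d += std::exp(l - m);
    return d;                                             // weight(i) = exp(logits[i] - m) / d
}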
63 changes: 41 additions & 22 deletions csrc/flash_attn_ck/mha_varlen_bwd.cpp
@@ -9,13 +9,27 @@

fmha_bwd_traits get_ck_fmha_varlen_bwd_traits(const mask_info &mask,
std::string dtype,
int seqlen_q,
int seqlen_k,
int batch,
int max_seqlen_q,
int max_seqlen_k,
int head_size,
int nhead_q,
int nhead_k,
bool has_dropout,
bool enable_alibi,
bool deterministic)
{
return fmha_bwd_traits{head_size,
head_size,
return fmha_bwd_traits{seqlen_q,
seqlen_k,
batch,
max_seqlen_q,
max_seqlen_k,
head_size, // hdim_q
head_size, // hdim_k
nhead_q,
nhead_k,
dtype,
true, // is_group_mode
mask.type,
@@ -25,7 +39,6 @@ fmha_bwd_traits get_ck_fmha_varlen_bwd_traits(const mask_info &mask,
false, // s_randval
deterministic};
}

fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
// sizes
const int b,
@@ -104,11 +117,11 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
ck_tile::index_t stride_dv = dv.stride(0);
ck_tile::index_t nhead_stride_dv = dv.stride(1);

// dq_acc: (split, total_q, nheads, hdim)
ck_tile::index_t split_stride_dq_acc = dq_acc.stride(0);
ck_tile::index_t batch_stride_dq_acc = 0;
ck_tile::index_t stride_dq_acc = dq_acc.stride(1);
ck_tile::index_t nhead_stride_dq_acc = dq_acc.stride(2);
// dq_acc: (nheads, split, total_q, hdim)
ck_tile::long_index_t batch_stride_dq_acc = 0;
ck_tile::long_index_t nhead_stride_dq_acc = dq_acc.stride(0);
ck_tile::index_t split_stride_dq_acc = dq_acc.stride(1);
ck_tile::index_t stride_dq_acc = dq_acc.stride(2);

float p_undrop = 1.0 - p_dropout;

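Note: the same dq_acc reordering as in mha_bwd.cpp, minus the batch dimension: group mode packs all sequences along total_q, so the layout is (nheads, split, total_q, hdim) and batch_stride_dq_acc stays 0. A small sketch of the packed addressing, with cu_seqlens_q supplying each sequence's base row (illustration, not the kernel's indexing code):

#include <cstdint>

// The batch offset is folded into the row index via cu_seqlens_q, which is
// why no batch stride is needed in group mode.
int64_t dq_acc_offset(int64_t head, int64_t split, int64_t batch, int64_t row_in_seq,
                      const int32_t* cu_seqlens_q,
                      int64_t nhead_stride, int64_t split_stride, int64_t row_stride) {
    const int64_t packed_row = cu_seqlens_q[batch] + row_in_seq;
    return head * nhead_stride + split * split_stride + packed_row * row_stride;
}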
@@ -233,7 +246,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads
#endif
if (is_causal) { window_size_right = 0; }

bool is_dropout = p_dropout > 0.0;
const bool is_dropout = p_dropout > 0.0;
auto stream = at::cuda::getCurrentCUDAStream().stream();

auto q_dtype = q.dtype();
@@ -247,7 +260,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads
TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");

std::string q_dtype_str = q_dtype == torch::kFloat16 ? "fp16" : "bf16";
const std::string q_dtype_str = q_dtype == torch::kFloat16 ? "fp16" : "bf16";

CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
@@ -330,19 +343,28 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads
dv = torch::empty_like(v);
}

const auto traits = get_ck_fmha_varlen_bwd_traits(
mask,
q_dtype_str,
total_q,
total_k,
batch_size,
max_seqlen_q,
max_seqlen_k,
head_size,
num_heads,
num_heads_k,
is_dropout,
alibi_slopes_.has_value(),
deterministic);
fmha_bwd_launcher launcher(traits);
const ck_tile::index_t nsplits = launcher.dq_acc_splits;

at::cuda::CUDAGuard device_guard{q.device()};

auto opts = q.options();
auto softmax_d = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
at::Tensor dq_accum;

if (!deterministic) {
dq_accum = torch::zeros({1, total_q, num_heads, head_size}, opts.dtype(at::kFloat));
} else {
const ck_tile::index_t kN0 = head_size <= 128 ? 128 : 64;
const ck_tile::index_t nsplits = ck_tile::integer_divide_ceil(max_seqlen_k, kN0);
dq_accum = torch::zeros({nsplits, total_q, num_heads, head_size}, opts.dtype(at::kFloat));
}
at::Tensor dq_accum = torch::zeros({num_heads, nsplits, total_q, head_size}, opts.dtype(at::kFloat));

at::Tensor dk_expanded, dv_expanded;
if (num_heads_k != num_heads) { // MQA / GQA
@@ -385,9 +407,6 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads
auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
ck_tile::stream_config stream_config{stream};

auto traits =
get_ck_fmha_varlen_bwd_traits(mask, q_dtype_str, head_size, is_dropout, alibi_slopes_.has_value(), deterministic);

auto args =
get_ck_fmha_varlen_bwd_args(
mask,
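Note: as in mha_bwd.cpp, nsplits now comes from fmha_bwd_launcher and dq_accum is allocated unconditionally. The two accumulator shapes after this change, side by side (sketch; assumes nsplits = launcher.dq_acc_splits and opts = q.options()):

#include <torch/torch.h>

// Batch mode: (batch, nheads, split, seqlen_q, hdim).
at::Tensor alloc_dq_acc(int64_t batch, int64_t nheads, int64_t nsplits,
                        int64_t seqlen_q, int64_t hdim, const at::TensorOptions& opts) {
    return torch::zeros({batch, nheads, nsplits, seqlen_q, hdim}, opts.dtype(at::kFloat));
}

// Group (varlen) mode: (nheads, split, total_q, hdim); sequences are packed
// along total_q, so there is no batch dimension.
at::Tensor alloc_dq_acc_varlen(int64_t nheads, int64_t nsplits,
                               int64_t total_q, int64_t hdim, const at::TensorOptions& opts) {
    return torch::zeros({nheads, nsplits, total_q, hdim}, opts.dtype(at::kFloat));
}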
23 changes: 21 additions & 2 deletions csrc/flash_attn_ck/mha_varlen_fwd.cpp
@@ -42,7 +42,8 @@ fmha_fwd_splitkv_traits get_ck_fmha_varlen_fwd_splitkv_traits(const mask_info &m
mask.type,
enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
has_lse,
false}; // do_fp8_static_quant
false, // do_fp8_static_quant
false}; // has_sink
}

fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
@@ -128,6 +129,10 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
nullptr, // seqlen_k_ptr
nullptr, // cu_seqlen_q_ptr
nullptr, // cu_seqlen_kv_ptr
nullptr, // block_scale_seqstart_q_ptr
nullptr, // block_scale_seqstart_k_ptr
nullptr, // seqstart_v_scale_ptr
nullptr, // sink_ptr
total_q,
total_k,
b,
@@ -144,27 +149,39 @@
stride_alibi_slopes,
stride_randval,
stride_o,
0, // stride_q_descale
0, // stride_k_descale
0, // stride_v_descale
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
0, // nhead_stride_bias, FA without bias
nhead_stride_randval,
nhead_stride_lse,
nhead_stride_o,
0, // nhead_stride_q_descale
0, // nhead_stride_k_descale
0, // nhead_stride_v_descale
batch_stride_q,
batch_stride_k,
batch_stride_v,
0, // batch_stride_bias, FA without bias
batch_stride_randval,
batch_stride_lse,
batch_stride_o,
0, // batch_stride_q_descale
0, // batch_stride_k_descale
0, // batch_stride_v_descale
mask.left,
mask.right,
0, // sink_size
static_cast<ck_tile::index_t>(mask.type),
0, // min_seqlen_q
p_dropout,
has_dropout_randval,
drop_seed_offset};
drop_seed_offset,
0, // block_scale_size_q
0}; // block_scale_size_kv
}

fmha_fwd_splitkv_args get_ck_fmha_varlen_fwd_splitkv_args(bool has_lse,
@@ -210,6 +227,7 @@ fmha_fwd_splitkv_args get_ck_fmha_varlen_fwd_splitkv_args(bool has_lse,
args.o_acc_ptr = out_acc.data_ptr();
args.lse_ptr = nullptr;
args.o_ptr = out.data_ptr();
args.sink_ptr = nullptr;

if (block_table_.has_value())
{
@@ -293,6 +311,7 @@ fmha_fwd_splitkv_args get_ck_fmha_varlen_fwd_splitkv_args(bool has_lse,

args.window_size_left = mask.left;
args.window_size_right = mask.right;
args.sink_size = 0;
args.mask_type = static_cast<ck_tile::index_t>(mask.type);

return args;
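Note: the varlen forward picks up the same fmha_fwd_args extensions as mha_fwd.cpp, all zeroed, and its split-KV path likewise nulls sink_ptr and zeroes sink_size. One piece of surrounding context worth spelling out is the block_table branch above: with a paged KV cache, logical token positions are translated through a per-batch page table. A conceptual sketch of that translation (assumed semantics; names are illustrative, not the CK API):

#include <cstdint>

// block_table[b][pos / page_size] names the physical page holding token
// `pos` of batch `b`; the row within the page is pos % page_size.
int64_t paged_kv_row(int64_t batch, int64_t pos, int64_t page_size,
                     const int32_t* block_table, int64_t block_table_batch_stride) {
    const int32_t page = block_table[batch * block_table_batch_stride + pos / page_size];
    return static_cast<int64_t>(page) * page_size + pos % page_size;
}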