Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions aiter/ops/mha.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,8 @@ def cmdGenFunc_mha_varlen_bwd(
alibi_slopes: Optional[Tensor] = None,
rng_state: Optional[Tensor] = None,
gen: Optional[Generator] = None,
cu_seqlens_q_padded: Optional[Tensor] = None,
cu_seqlens_k_padded: Optional[Tensor] = None,
) -> dict[str, Any]:
md_name = "mha_varlen_bwd"
filter1 = "*" # get_bwd_dot_do_o_blobs()
Expand Down Expand Up @@ -1081,6 +1083,8 @@ def mha_varlen_bwd(
alibi_slopes: Optional[Tensor] = None,
rng_state: Optional[Tensor] = None,
gen: Optional[Generator] = None,
cu_seqlens_q_padded: Optional[Tensor] = None,
cu_seqlens_k_padded: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ...


Expand Down Expand Up @@ -1110,6 +1114,8 @@ def gen_fmha_v3_varlen_bwd_fake_tensor(
alibi_slopes: Optional[Tensor] = None,
rng_state: Optional[Tensor] = None,
gen: Optional[Generator] = None,
cu_seqlens_q_padded: Optional[Tensor] = None,
cu_seqlens_k_padded: Optional[Tensor] = None,
):
return gen_mha_varlen_bwd_fake_tensors_common(
q, k, v, cu_seqlens_q, max_seqlen_q, zero_tensors, dq, dk, dv
Expand All @@ -1131,8 +1137,6 @@ def fmha_v3_varlen_bwd(
softmax_lse: Tensor,
cu_seqlens_q: Tensor,
cu_seqlens_k: Tensor,
# cu_seqlens_q_padded: Tensor,
# cu_seqlens_k_padded: Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
dropout_p: float,
Expand All @@ -1150,6 +1154,8 @@ def fmha_v3_varlen_bwd(
alibi_slopes: Optional[Tensor] = None,
rng_state: Optional[Tensor] = None,
gen: Optional[Generator] = None,
cu_seqlens_q_padded: Optional[Tensor] = None,
cu_seqlens_k_padded: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ...


Expand Down Expand Up @@ -1952,10 +1958,6 @@ def _flash_attn_varlen_backward(
dv: Optional[torch.Tensor],
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
# FIXME: this two args currently not support on ck side
# and has no host code on aiter side
# cu_seqlens_q_padded: Tensor,
# cu_seqlens_k_padded: Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
dropout_p: float,
Expand All @@ -1969,6 +1971,8 @@ def _flash_attn_varlen_backward(
is_v3_atomic_fp32: Optional[bool] = True,
how_v3_bf16_cvt: Optional[int] = 1,
zero_tensors: bool = False,
cu_seqlens_q_padded: Optional[torch.Tensor] = None,
cu_seqlens_k_padded: Optional[torch.Tensor] = None,
) -> torch.Tensor:

(_, nhead_q, hdim_q) = q.shape
Expand Down Expand Up @@ -2077,8 +2081,6 @@ def can_impl_fmha_v3_bwd_gfx950():
softmax_lse,
cu_seqlens_q,
cu_seqlens_k,
# cu_seqlens_q_padded,
# cu_seqlens_k_padded,
max_seqlen_q,
max_seqlen_k,
dropout_p,
Expand All @@ -2096,6 +2098,8 @@ def can_impl_fmha_v3_bwd_gfx950():
alibi_slopes,
rng_state,
None,
cu_seqlens_q_padded,
cu_seqlens_k_padded,
)
else:
(
Expand Down Expand Up @@ -2127,6 +2131,8 @@ def can_impl_fmha_v3_bwd_gfx950():
alibi_slopes,
rng_state,
None,
cu_seqlens_q_padded,
cu_seqlens_k_padded,
# custom_build_args={"md_name": md_name, "blob_gen_cmd": blob_gen_cmd},
)
return softmax_d
Expand Down Expand Up @@ -2216,6 +2222,8 @@ def forward(
ctx.head_size_q_og = head_size_q_og
ctx.is_v3_atomic_fp32 = is_v3_atomic_fp32
ctx.how_v3_bf16_cvt = how_v3_bf16_cvt
ctx.cu_seqlens_q_padded = cu_seqlens_q_padded
ctx.cu_seqlens_k_padded = cu_seqlens_k_padded

out = out_padded[..., :head_size_v_og]

Expand Down Expand Up @@ -2272,6 +2280,8 @@ def backward(ctx, dout, *args):
rng_state=rng_state,
is_v3_atomic_fp32=ctx.is_v3_atomic_fp32,
how_v3_bf16_cvt=ctx.how_v3_bf16_cvt,
cu_seqlens_q_padded=ctx.cu_seqlens_q_padded,
cu_seqlens_k_padded=ctx.cu_seqlens_k_padded,
)
dq = dq[..., :head_size_q_og] # We could have padded the head dimension
dk = dk[..., :head_size_q_og]
Expand Down
10 changes: 5 additions & 5 deletions csrc/include/mha_bwd.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ __attribute__((visibility("default"))) float mha_bwd(mha_bwd_args args,
int how_v3_bf16_cvt,
const void* seqlen_q_padded = nullptr,
const void* seqlen_k_padded = nullptr,
bool is_v3_api_check = false);
bool is_v3_api_check = false);

struct __attribute__((packed)) fmha_bwd_v3_args
{
Expand Down Expand Up @@ -364,9 +364,9 @@ struct __attribute__((packed)) fmha_bwd_dq_shuffle_args
p3 _p9;
unsigned int head_dim;
p3 _p10;
const void *ptr_qseq;
const void* ptr_qseq;
p2 _p11;
const void *ptr_qseq_padded;
const void* ptr_qseq_padded;
p2 _p12;
unsigned int max_seqlen_dq;
p3 _p13;
Expand Down Expand Up @@ -422,7 +422,7 @@ float fmha_bwd_v3(mha_bwd_traits t,
const ck_tile::stream_config& s,
const void* seqlen_q_padded = nullptr,
const void* seqlen_k_padded = nullptr,
bool is_v3_api_check = false);
bool is_v3_api_check = false);
}

namespace gfx950 {
Expand All @@ -431,6 +431,6 @@ float fmha_bwd_v3(mha_bwd_traits t,
const ck_tile::stream_config& s,
const void* seqlen_q_padded = nullptr,
const void* seqlen_k_padded = nullptr,
bool is_v3_api_check = false);
bool is_v3_api_check = false);
}
} // namespace aiter
112 changes: 58 additions & 54 deletions csrc/include/rocm_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -539,34 +539,36 @@ namespace py = pybind11;
py::arg("rng_state") = std::nullopt, \
py::arg("gen") = std::nullopt);

#define MHA_VARLEN_BWD_ASM_PYBIND \
m.def("fmha_v3_varlen_bwd", \
&aiter::torch_itfs::fmha_v3_varlen_bwd, \
py::arg("dout"), \
py::arg("q"), \
py::arg("k"), \
py::arg("v"), \
py::arg("out"), \
py::arg("softmax_lse"), \
py::arg("cu_seqlens_q"), \
py::arg("cu_seqlens_k"), \
py::arg("max_seqlen_q"), \
py::arg("max_seqlen_k"), \
py::arg("dropout_p"), \
py::arg("softmax_scale"), \
py::arg("zero_tensors"), \
py::arg("is_causal"), \
py::arg("window_size_left"), \
py::arg("window_size_right"), \
py::arg("deterministic"), \
py::arg("is_v3_atomic_fp32"), \
py::arg("how_v3_bf16_cvt"), \
py::arg("dq") = std::nullopt, \
py::arg("dk") = std::nullopt, \
py::arg("dv") = std::nullopt, \
py::arg("alibi_slopes") = std::nullopt, \
py::arg("rng_state") = std::nullopt, \
py::arg("gen") = std::nullopt);
// Expands to the pybind11 registration of the Python-visible function
// "fmha_v3_varlen_bwd", bound to aiter::torch_itfs::fmha_v3_varlen_bwd.
// The expansion calls `m.def(...)`, so a module object named `m` must be in
// scope wherever this macro is used.
//
// Keyword arguments "dout" through "how_v3_bf16_cvt" are declared without a
// default value. The remaining ones ("dq", "dk", "dv", "alibi_slopes",
// "rng_state", "gen", "cu_seqlens_q_padded", "cu_seqlens_k_padded") default
// to std::nullopt.
//
// NOTE(review): the py::arg list presumably has to mirror the parameter order
// of the bound C++ function -- confirm against its declaration when editing.
#define MHA_VARLEN_BWD_ASM_PYBIND \
m.def("fmha_v3_varlen_bwd", \
&aiter::torch_itfs::fmha_v3_varlen_bwd, \
py::arg("dout"), \
py::arg("q"), \
py::arg("k"), \
py::arg("v"), \
py::arg("out"), \
py::arg("softmax_lse"), \
py::arg("cu_seqlens_q"), \
py::arg("cu_seqlens_k"), \
py::arg("max_seqlen_q"), \
py::arg("max_seqlen_k"), \
py::arg("dropout_p"), \
py::arg("softmax_scale"), \
py::arg("zero_tensors"), \
py::arg("is_causal"), \
py::arg("window_size_left"), \
py::arg("window_size_right"), \
py::arg("deterministic"), \
py::arg("is_v3_atomic_fp32"), \
py::arg("how_v3_bf16_cvt"), \
py::arg("dq") = std::nullopt, \
py::arg("dk") = std::nullopt, \
py::arg("dv") = std::nullopt, \
py::arg("alibi_slopes") = std::nullopt, \
py::arg("rng_state") = std::nullopt, \
py::arg("gen") = std::nullopt, \
py::arg("cu_seqlens_q_padded") = std::nullopt, \
py::arg("cu_seqlens_k_padded") = std::nullopt);

#define MHA_BWD_PYBIND \
m.def("mha_bwd", \
Expand Down Expand Up @@ -657,32 +659,34 @@ namespace py = pybind11;
py::arg("alibi_slopes") = std::nullopt, \
py::arg("gen") = std::nullopt);

#define MHA_VARLEN_BWD_PYBIND \
m.def("mha_varlen_bwd", \
&aiter::torch_itfs::mha_varlen_bwd, \
py::arg("dout"), \
py::arg("q"), \
py::arg("k"), \
py::arg("v"), \
py::arg("out"), \
py::arg("softmax_lse"), \
py::arg("cu_seqlens_q"), \
py::arg("cu_seqlens_k"), \
py::arg("max_seqlen_q"), \
py::arg("max_seqlen_k"), \
py::arg("dropout_p"), \
py::arg("softmax_scale"), \
py::arg("zero_tensors"), \
py::arg("is_causal"), \
py::arg("window_size_left"), \
py::arg("window_size_right"), \
py::arg("deterministic"), \
py::arg("dq") = std::nullopt, \
py::arg("dk") = std::nullopt, \
py::arg("dv") = std::nullopt, \
py::arg("alibi_slopes") = std::nullopt, \
py::arg("rng_state") = std::nullopt, \
py::arg("gen") = std::nullopt);
// Expands to the pybind11 registration of the Python-visible function
// "mha_varlen_bwd", bound to aiter::torch_itfs::mha_varlen_bwd.
// The expansion calls `m.def(...)`, so a module object named `m` must be in
// scope wherever this macro is used.
//
// Keyword arguments "dout" through "deterministic" are declared without a
// default value. The remaining ones ("dq", "dk", "dv", "alibi_slopes",
// "rng_state", "gen", "cu_seqlens_q_padded", "cu_seqlens_k_padded") default
// to std::nullopt.
//
// NOTE(review): the py::arg list presumably has to mirror the parameter order
// of the bound C++ function -- confirm against its declaration when editing.
#define MHA_VARLEN_BWD_PYBIND \
m.def("mha_varlen_bwd", \
&aiter::torch_itfs::mha_varlen_bwd, \
py::arg("dout"), \
py::arg("q"), \
py::arg("k"), \
py::arg("v"), \
py::arg("out"), \
py::arg("softmax_lse"), \
py::arg("cu_seqlens_q"), \
py::arg("cu_seqlens_k"), \
py::arg("max_seqlen_q"), \
py::arg("max_seqlen_k"), \
py::arg("dropout_p"), \
py::arg("softmax_scale"), \
py::arg("zero_tensors"), \
py::arg("is_causal"), \
py::arg("window_size_left"), \
py::arg("window_size_right"), \
py::arg("deterministic"), \
py::arg("dq") = std::nullopt, \
py::arg("dk") = std::nullopt, \
py::arg("dv") = std::nullopt, \
py::arg("alibi_slopes") = std::nullopt, \
py::arg("rng_state") = std::nullopt, \
py::arg("gen") = std::nullopt, \
py::arg("cu_seqlens_q_padded") = std::nullopt, \
py::arg("cu_seqlens_k_padded") = std::nullopt);

#define MOE_CK_2STAGES_PYBIND \
m.def("ck_moe_stage1", \
Expand Down
8 changes: 3 additions & 5 deletions csrc/include/torch/mha_v3_varlen_bwd.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@ fmha_v3_varlen_bwd(const at::Tensor& dout, // [total_q, hq, d_v]
const at::Tensor& softmax_lse, // [b, hq, sq]
const at::Tensor& cu_seqlens_q, // [b+1]
const at::Tensor& cu_seqlens_k, // [b+1]
// FIXME: this two args currently not support on ck side
// and has no host code on aiter side
// const at::Tensor& cu_seqlens_q_padded, // [b+1]
// const at::Tensor& cu_seqlens_k_padded, // [b+1]
const int max_seqlen_q,
const int max_seqlen_k,
const float p_dropout,
Expand All @@ -34,7 +30,9 @@ fmha_v3_varlen_bwd(const at::Tensor& dout, // [total_q, hq, d_v]
std::optional<at::Tensor> dv_, // [total_k, hk, d_v]
std::optional<const at::Tensor> alibi_slopes_, // [hq] or [b, hq]
std::optional<const at::Tensor> rng_state_,
std::optional<at::Generator> gen_);
std::optional<at::Generator> gen_,
std::optional<const at::Tensor> cu_seqlens_q_padded = std::nullopt,
std::optional<const at::Tensor> cu_seqlens_k_padded = std::nullopt);

} // namespace torch_itfs
} // namespace aiter
7 changes: 5 additions & 2 deletions csrc/include/torch/mha_varlen_bwd.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#pragma once
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <torch/extension.h>

namespace aiter {
Expand Down Expand Up @@ -28,6 +28,9 @@ mha_varlen_bwd(const at::Tensor& dout, // [total_q, hq, d]
std::optional<at::Tensor> dv, // [total_k, hk, d]
std::optional<const at::Tensor> alibi_slopes, // [hq] or [b, hq]
std::optional<const at::Tensor> rng_state,
std::optional<at::Generator> gen);
std::optional<at::Generator> gen,
std::optional<const at::Tensor> cu_seqlens_q_padded, // [b+1]
std::optional<const at::Tensor> cu_seqlens_k_padded // [b+1]
);
} // namespace torch_itfs
} // namespace aiter
9 changes: 6 additions & 3 deletions csrc/py_itfs_ck/mha_bwd_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -145,17 +145,20 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
dv.data_ptr(),
dbias_ptr,
dq_acc.data_ptr(), // dq_acc
nullptr, // seqstart_q
nullptr, // seqstart_k
nullptr, // seqstart_q_ptr
nullptr, // seqstart_k_ptr
nullptr, // seqlen_q_ptr
nullptr, // seqlen_k_ptr
nullptr, // cu_seqlen_q_ptr
nullptr, // cu_seqlen_k_ptr
seqlen_q,
seqlen_k,
b,
seqlen_q, // max_seqlen_q
seqlen_k, // max_seqlen_k
hdim_q, // hdim_q
hdim_v, // hdim_v
h, // nhead
h, // nhead_q
h_k, // nhead_k
softmax_scale,
stride_q,
Expand Down
17 changes: 8 additions & 9 deletions csrc/py_itfs_ck/mha_fwd_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -97,20 +97,19 @@ mha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
has_dropout_randval ? dropout_randval.data_ptr() : nullptr,
has_lse ? softmax_lse.data_ptr() : nullptr,
out.data_ptr(),
cu_seqlen_q_ptr,
cu_seqlen_kv_ptr,
nullptr, // seqstart_q
nullptr, // seqstart_k
nullptr,
nullptr, // seqstart_padded_q_ptr
nullptr, // seqstart_padded_k_ptr
nullptr, // seqstart_q_ptr
nullptr, // seqstart_k_ptr
nullptr, // seqlen_q_ptr
nullptr, // seqlen_k_ptr
cu_seqlen_q_ptr, // cu_seqlen_q_ptr
cu_seqlen_kv_ptr, // cu_seqlen_k_ptr
seqlen_q,
seqlen_k,
b,
seqlen_q, // max_seqlen_q
d, // hdim_q
d_v, // hdim_v
h, // nhead
h, // nhead_q
h_k, // nhead_k
softmax_scale, // scale_s
1, // scale_p
Expand Down Expand Up @@ -139,7 +138,7 @@ mha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
mask.left,
mask.right,
static_cast<ck_tile::index_t>(mask.type),
0,
0, // min_seqlen_q
p_dropout,
has_dropout_randval,
drop_seed_offset};
Expand Down
Loading