Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions aiter/ops/mha.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ def gen_fmha_v3_fwd_fake_tensors(
window_size_right: int,
return_softmax_lse: bool,
return_dropout_randval: bool,
how_v3_bf16_cvt: int,
out: Optional[Tensor] = None,
bias: Optional[Tensor] = None,
alibi_slopes: Optional[Tensor] = None,
Expand All @@ -226,6 +227,7 @@ def fmha_v3_fwd(
window_size_right: int,
return_softmax_lse: bool,
return_dropout_randval: bool,
how_v3_bf16_cvt: int,
out: Optional[Tensor] = None,
bias: Optional[Tensor] = None,
alibi_slopes: Optional[Tensor] = None,
Expand Down Expand Up @@ -1176,6 +1178,7 @@ def _flash_attn_forward(
alibi_slopes: Optional[torch.Tensor],
return_lse: bool,
return_softmax: bool,
how_v3_bf16_cvt: Optional[int] = 1,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
Expand Down Expand Up @@ -1229,6 +1232,7 @@ def _validate_cu(name: str, x: Optional[torch.Tensor]):
window_size_right,
return_lse,
return_softmax,
how_v3_bf16_cvt,
None,
bias,
alibi_slopes,
Expand Down Expand Up @@ -1626,6 +1630,7 @@ def forward(
alibi_slopes=alibi_slopes,
return_lse=return_lse,
return_softmax=return_softmax and dropout_p > 0,
how_v3_bf16_cvt=how_v3_bf16_cvt,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
)
Expand Down Expand Up @@ -1742,6 +1747,7 @@ def flash_attn_func(
deterministic=True,
return_lse=False,
return_attn_probs=False,
how_v3_bf16_cvt=1,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None,
):
Expand Down Expand Up @@ -1811,7 +1817,7 @@ def flash_attn_func(
return_attn_probs,
torch.is_grad_enabled(),
True, # is_v3_atomic_fp32
1, # how_v3_bf16_cvt
how_v3_bf16_cvt,
cu_seqlens_q,
cu_seqlens_kv,
)
Expand Down Expand Up @@ -2352,6 +2358,7 @@ def flash_attn_varlen_func(
deterministic=False,
return_lse=False,
return_attn_probs=False,
how_v3_bf16_cvt=1,
block_table=None,
out=None,
cu_seqlens_q_padded: Optional[torch.Tensor] = None,
Expand Down Expand Up @@ -2445,7 +2452,7 @@ def flash_attn_varlen_func(
cu_seqlens_q_padded,
cu_seqlens_k_padded,
True,
1,
how_v3_bf16_cvt,
)


Expand Down
1 change: 1 addition & 0 deletions csrc/include/rocm_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,7 @@ namespace py = pybind11;
py::arg("window_size_right"), \
py::arg("return_softmax_lse"), \
py::arg("return_dropout_randval"), \
py::arg("how_v3_bf16_cvt"), \
py::arg("out") = std::nullopt, \
py::arg("bias") = std::nullopt, \
py::arg("alibi_slopes") = std::nullopt, \
Expand Down
1 change: 1 addition & 0 deletions csrc/include/torch/mha_v3_fwd.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ std::vector<at::Tensor> fmha_v3_fwd(at::Tensor &q, // [b, sq, hq, d]
int window_size_right,
bool return_softmax_lse,
bool return_dropout_randval,
int how_v3_bf16_cvt,
std::optional<at::Tensor> out_, // [b, sq, hq, d_v]
std::optional<const at::Tensor> bias_, // [sq, sk]
std::optional<const at::Tensor> alibi_slopes_, // [hq] or [b, hq]
Expand Down
4 changes: 3 additions & 1 deletion csrc/py_itfs_cu/asm_mha_fwd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ std::vector<at::Tensor> fmha_v3_fwd(at::Tensor &q, // [b, sq, hq, d]
int window_size_right,
bool return_softmax_lse,
bool return_dropout_randval,
int how_v3_bf16_cvt,
std::optional<at::Tensor> out_, // [b, sq, hq, d_v]
std::optional<const at::Tensor> bias_, // [sq, sk]
std::optional<const at::Tensor> alibi_slopes_, // [hq] or [b, hq]
Expand Down Expand Up @@ -316,7 +317,8 @@ std::vector<at::Tensor> fmha_v3_fwd(at::Tensor &q, // [b, sq, hq, d]
mask.type,
bias_type,
has_lse,
true);
true,
how_v3_bf16_cvt);
TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd");
}
else {
Expand Down
9 changes: 5 additions & 4 deletions op_tests/test_mha.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,11 @@ def run_ck(
bias,
alibi_slopes,
deterministic,
return_lse,
return_attn_probs,
cu_seqlens_q,
cu_seqlens_kv,
return_lse=return_lse,
return_attn_probs=return_attn_probs,
how_v3_bf16_cvt=1,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
num_rotate_args=1,
)

Expand Down