diff --git a/aiter/ops/mha.py b/aiter/ops/mha.py
index ebf7f6f28e..266012574f 100644
--- a/aiter/ops/mha.py
+++ b/aiter/ops/mha.py
@@ -202,6 +202,7 @@ def gen_fmha_v3_fwd_fake_tensors(
     window_size_right: int,
     return_softmax_lse: bool,
     return_dropout_randval: bool,
+    how_v3_bf16_cvt: int,
     out: Optional[Tensor] = None,
     bias: Optional[Tensor] = None,
     alibi_slopes: Optional[Tensor] = None,
@@ -226,6 +227,7 @@ def fmha_v3_fwd(
     window_size_right: int,
     return_softmax_lse: bool,
     return_dropout_randval: bool,
+    how_v3_bf16_cvt: int,
     out: Optional[Tensor] = None,
     bias: Optional[Tensor] = None,
     alibi_slopes: Optional[Tensor] = None,
@@ -1176,6 +1178,7 @@ def _flash_attn_forward(
     alibi_slopes: Optional[torch.Tensor],
     return_lse: bool,
     return_softmax: bool,
+    how_v3_bf16_cvt: Optional[int] = 1,
     cu_seqlens_q: Optional[torch.Tensor] = None,
     cu_seqlens_kv: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -1229,6 +1232,7 @@ def _validate_cu(name: str, x: Optional[torch.Tensor]):
         window_size_right,
         return_lse,
         return_softmax,
+        how_v3_bf16_cvt,
         None,
         bias,
         alibi_slopes,
@@ -1626,6 +1630,7 @@ def forward(
             alibi_slopes=alibi_slopes,
             return_lse=return_lse,
             return_softmax=return_softmax and dropout_p > 0,
+            how_v3_bf16_cvt=how_v3_bf16_cvt,
             cu_seqlens_q=cu_seqlens_q,
             cu_seqlens_kv=cu_seqlens_kv,
         )
@@ -1742,6 +1747,7 @@ def flash_attn_func(
     deterministic=True,
     return_lse=False,
     return_attn_probs=False,
+    how_v3_bf16_cvt=1,
     cu_seqlens_q: Optional[torch.Tensor] = None,
     cu_seqlens_kv: Optional[torch.Tensor] = None,
 ):
@@ -1811,7 +1817,7 @@ def flash_attn_func(
         return_attn_probs,
         torch.is_grad_enabled(),
         True,  # is_v3_atomic_fp32
-        1,  # how_v3_bf16_cvt
+        how_v3_bf16_cvt,
         cu_seqlens_q,
         cu_seqlens_kv,
     )
@@ -2352,6 +2358,7 @@ def flash_attn_varlen_func(
     deterministic=False,
     return_lse=False,
     return_attn_probs=False,
+    how_v3_bf16_cvt=1,
     block_table=None,
     out=None,
     cu_seqlens_q_padded: Optional[torch.Tensor] = None,
@@ -2445,7 +2452,7 @@ def flash_attn_varlen_func(
         cu_seqlens_q_padded,
         cu_seqlens_k_padded,
         True,
-        1,
+        how_v3_bf16_cvt,
     )
diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp
index 7926085f17..74e3a9638e 100644
--- a/csrc/include/rocm_ops.hpp
+++ b/csrc/include/rocm_ops.hpp
@@ -614,6 +614,7 @@ namespace py = pybind11;
           py::arg("window_size_right"),                                    \
           py::arg("return_softmax_lse"),                                   \
           py::arg("return_dropout_randval"),                               \
+          py::arg("how_v3_bf16_cvt"),                                      \
           py::arg("out") = std::nullopt,                                   \
           py::arg("bias") = std::nullopt,                                  \
           py::arg("alibi_slopes") = std::nullopt,                          \
diff --git a/csrc/include/torch/mha_v3_fwd.h b/csrc/include/torch/mha_v3_fwd.h
index e1b0543d48..9ec33136fc 100644
--- a/csrc/include/torch/mha_v3_fwd.h
+++ b/csrc/include/torch/mha_v3_fwd.h
@@ -15,6 +15,7 @@ std::vector<at::Tensor> fmha_v3_fwd(at::Tensor &q, // [b, sq, hq, d]
     int window_size_right,
     bool return_softmax_lse,
     bool return_dropout_randval,
+    int how_v3_bf16_cvt,
     std::optional<at::Tensor> out_,          // [b, sq, hq, d_v]
     std::optional<at::Tensor> bias_,         // [sq, sk]
     std::optional<at::Tensor> alibi_slopes_, // [hq] or [b, hq]
diff --git a/csrc/py_itfs_cu/asm_mha_fwd.cu b/csrc/py_itfs_cu/asm_mha_fwd.cu
index f5cbf24d86..33cd53ca6b 100644
--- a/csrc/py_itfs_cu/asm_mha_fwd.cu
+++ b/csrc/py_itfs_cu/asm_mha_fwd.cu
@@ -149,6 +149,7 @@ std::vector<at::Tensor> fmha_v3_fwd(at::Tensor &q, // [b, sq, hq, d]
     int window_size_right,
     bool return_softmax_lse,
     bool return_dropout_randval,
+    int how_v3_bf16_cvt,
     std::optional<at::Tensor> out_,          // [b, sq, hq, d_v]
     std::optional<at::Tensor> bias_,         // [sq, sk]
     std::optional<at::Tensor> alibi_slopes_, // [hq] or [b, hq]
@@ -316,7 +317,8 @@ std::vector<at::Tensor> fmha_v3_fwd(at::Tensor &q, // [b, sq, hq, d]
             mask.type,
             bias_type,
             has_lse,
-            true);
+            true,
+            how_v3_bf16_cvt);
         TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd");
     }
     else {
diff --git a/op_tests/test_mha.py b/op_tests/test_mha.py
index 1b2ea0d9cc..669196cdef 100644
--- a/op_tests/test_mha.py
+++ b/op_tests/test_mha.py
@@ -104,10 +104,11 @@ def run_ck(
         bias,
         alibi_slopes,
         deterministic,
-        return_lse,
-        return_attn_probs,
-        cu_seqlens_q,
-        cu_seqlens_kv,
+        return_lse=return_lse,
+        return_attn_probs=return_attn_probs,
+        how_v3_bf16_cvt=1,
+        cu_seqlens_q=cu_seqlens_q,
+        cu_seqlens_kv=cu_seqlens_kv,
         num_rotate_args=1,
     )