diff --git a/3rdparty/composable_kernel b/3rdparty/composable_kernel
index 515e283091..7c6430eca0 160000
--- a/3rdparty/composable_kernel
+++ b/3rdparty/composable_kernel
@@ -1 +1 @@
-Subproject commit 515e28309153ae8ab6fa3cbed81b44e2c01c43cd
+Subproject commit 7c6430eca04e62454217630ae2a0bbd70ff50a00
diff --git a/aiter/ops/mha.py b/aiter/ops/mha.py
index 26b591d348..ebf7f6f28e 100644
--- a/aiter/ops/mha.py
+++ b/aiter/ops/mha.py
@@ -848,6 +848,8 @@ def cmdGenFunc_mha_varlen_bwd(
     alibi_slopes: Optional[Tensor] = None,
     rng_state: Optional[Tensor] = None,
     gen: Optional[Generator] = None,
+    cu_seqlens_q_padded: Optional[Tensor] = None,
+    cu_seqlens_k_padded: Optional[Tensor] = None,
 ) -> dict[str, Any]:
     md_name = "mha_varlen_bwd"
     filter1 = "*"  # get_bwd_dot_do_o_blobs()
@@ -1081,6 +1083,8 @@ def mha_varlen_bwd(
     alibi_slopes: Optional[Tensor] = None,
     rng_state: Optional[Tensor] = None,
     gen: Optional[Generator] = None,
+    cu_seqlens_q_padded: Optional[Tensor] = None,
+    cu_seqlens_k_padded: Optional[Tensor] = None,
 ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ...
@@ -1110,6 +1114,8 @@ def gen_fmha_v3_varlen_bwd_fake_tensor(
     alibi_slopes: Optional[Tensor] = None,
     rng_state: Optional[Tensor] = None,
     gen: Optional[Generator] = None,
+    cu_seqlens_q_padded: Optional[Tensor] = None,
+    cu_seqlens_k_padded: Optional[Tensor] = None,
 ):
     return gen_mha_varlen_bwd_fake_tensors_common(
         q, k, v, cu_seqlens_q, max_seqlen_q, zero_tensors, dq, dk, dv
@@ -1131,8 +1137,6 @@ def fmha_v3_varlen_bwd(
     softmax_lse: Tensor,
     cu_seqlens_q: Tensor,
     cu_seqlens_k: Tensor,
-    # cu_seqlens_q_padded: Tensor,
-    # cu_seqlens_k_padded: Tensor,
     max_seqlen_q: int,
     max_seqlen_k: int,
     dropout_p: float,
@@ -1150,6 +1154,8 @@ def fmha_v3_varlen_bwd(
     alibi_slopes: Optional[Tensor] = None,
     rng_state: Optional[Tensor] = None,
     gen: Optional[Generator] = None,
+    cu_seqlens_q_padded: Optional[Tensor] = None,
+    cu_seqlens_k_padded: Optional[Tensor] = None,
 ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ...
@@ -1952,10 +1958,6 @@ def _flash_attn_varlen_backward(
     dv: Optional[torch.Tensor],
     cu_seqlens_q: torch.Tensor,
     cu_seqlens_k: torch.Tensor,
-    # FIXME: this two args currently not support on ck side
-    # and has no host code on aiter side
-    # cu_seqlens_q_padded: Tensor,
-    # cu_seqlens_k_padded: Tensor,
     max_seqlen_q: int,
     max_seqlen_k: int,
     dropout_p: float,
@@ -1969,6 +1971,8 @@ def _flash_attn_varlen_backward(
     is_v3_atomic_fp32: Optional[bool] = True,
     how_v3_bf16_cvt: Optional[int] = 1,
     zero_tensors: bool = False,
+    cu_seqlens_q_padded: Optional[torch.Tensor] = None,
+    cu_seqlens_k_padded: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     (_, nhead_q, hdim_q) = q.shape
@@ -2077,8 +2081,6 @@ def can_impl_fmha_v3_bwd_gfx950():
             softmax_lse,
             cu_seqlens_q,
             cu_seqlens_k,
-            # cu_seqlens_q_padded,
-            # cu_seqlens_k_padded,
             max_seqlen_q,
             max_seqlen_k,
             dropout_p,
@@ -2096,6 +2098,8 @@ def can_impl_fmha_v3_bwd_gfx950():
             alibi_slopes,
             rng_state,
             None,
+            cu_seqlens_q_padded,
+            cu_seqlens_k_padded,
         )
     else:
         (
@@ -2127,6 +2131,8 @@ def can_impl_fmha_v3_bwd_gfx950():
             alibi_slopes,
             rng_state,
             None,
+            cu_seqlens_q_padded,
+            cu_seqlens_k_padded,
             # custom_build_args={"md_name": md_name, "blob_gen_cmd": blob_gen_cmd},
         )
     return softmax_d
@@ -2216,6 +2222,8 @@ def forward(
         ctx.head_size_q_og = head_size_q_og
         ctx.is_v3_atomic_fp32 = is_v3_atomic_fp32
         ctx.how_v3_bf16_cvt = how_v3_bf16_cvt
+        ctx.cu_seqlens_q_padded = cu_seqlens_q_padded
+        ctx.cu_seqlens_k_padded = cu_seqlens_k_padded

         out = out_padded[..., :head_size_v_og]

@@ -2272,6 +2280,8 @@ def backward(ctx, dout, *args):
             rng_state=rng_state,
             is_v3_atomic_fp32=ctx.is_v3_atomic_fp32,
             how_v3_bf16_cvt=ctx.how_v3_bf16_cvt,
+            cu_seqlens_q_padded=ctx.cu_seqlens_q_padded,
+            cu_seqlens_k_padded=ctx.cu_seqlens_k_padded,
         )
         dq = dq[..., :head_size_q_og]  # We could have padded the head dimension
         dk = dk[..., :head_size_q_og]
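The aiter/ops/mha.py hunks above thread two new optional arguments through every varlen-backward entry point: cu_seqlens_q_padded and cu_seqlens_k_padded, the cumulative physical (padded) offsets that accompany the existing logical cu_seqlens_q/cu_seqlens_k. A minimal sketch of how a caller might build the pair; the helper name `cu` is hypothetical, only torch and the keyword arguments added by this patch are assumed:

    import torch

    def cu(lens, device="cuda"):
        # Cumulative int32 offsets with a leading zero, the [b+1] layout that
        # cu_seqlens_* / cu_seqlens_*_padded use throughout this patch.
        t = torch.tensor([0] + list(lens), dtype=torch.int32)
        return torch.cumsum(t, dim=0, dtype=torch.int32).to(device)

    cu_seqlens_q = cu([100, 60])          # logical lengths   -> [0, 100, 160]
    cu_seqlens_q_padded = cu([128, 128])  # padded slot sizes -> [0, 128, 256]
    # Both are then passed as flash_attn_varlen_func(..., cu_seqlens_q_padded=...,
    # cu_seqlens_k_padded=...), as exercised by run_ck_seq_padding further down.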
diff --git a/csrc/include/mha_bwd.h b/csrc/include/mha_bwd.h
index ec4315a1a7..635aaf9117 100644
--- a/csrc/include/mha_bwd.h
+++ b/csrc/include/mha_bwd.h
@@ -61,7 +61,7 @@ __attribute__((visibility("default"))) float mha_bwd(mha_bwd_args args,
                  int how_v3_bf16_cvt,
                  const void* seqlen_q_padded = nullptr,
                  const void* seqlen_k_padded = nullptr,
-                 bool is_v3_api_check = false);
+                 bool is_v3_api_check = false);

 struct __attribute__((packed)) fmha_bwd_v3_args
 {
@@ -364,9 +364,9 @@ struct __attribute__((packed)) fmha_bwd_dq_shuffle_args
     p3 _p9;
     unsigned int head_dim;
     p3 _p10;
-    const void *ptr_qseq;
+    const void* ptr_qseq;
     p2 _p11;
-    const void *ptr_qseq_padded;
+    const void* ptr_qseq_padded;
     p2 _p12;
     unsigned int max_seqlen_dq;
     p3 _p13;
@@ -422,7 +422,7 @@ float fmha_bwd_v3(mha_bwd_traits t,
                   const ck_tile::stream_config& s,
                   const void* seqlen_q_padded = nullptr,
                   const void* seqlen_k_padded = nullptr,
-                  bool is_v3_api_check = false);
+                  bool is_v3_api_check = false);
 }
 namespace gfx950 {
@@ -431,6 +431,6 @@ float fmha_bwd_v3(mha_bwd_traits t,
                   const ck_tile::stream_config& s,
                   const void* seqlen_q_padded = nullptr,
                   const void* seqlen_k_padded = nullptr,
-                  bool is_v3_api_check = false);
+                  bool is_v3_api_check = false);
 }
 } // namespace aiter
py::arg("q"), \ - py::arg("k"), \ - py::arg("v"), \ - py::arg("out"), \ - py::arg("softmax_lse"), \ - py::arg("cu_seqlens_q"), \ - py::arg("cu_seqlens_k"), \ - py::arg("max_seqlen_q"), \ - py::arg("max_seqlen_k"), \ - py::arg("dropout_p"), \ - py::arg("softmax_scale"), \ - py::arg("zero_tensors"), \ - py::arg("is_causal"), \ - py::arg("window_size_left"), \ - py::arg("window_size_right"), \ - py::arg("deterministic"), \ - py::arg("is_v3_atomic_fp32"), \ - py::arg("how_v3_bf16_cvt"), \ - py::arg("dq") = std::nullopt, \ - py::arg("dk") = std::nullopt, \ - py::arg("dv") = std::nullopt, \ - py::arg("alibi_slopes") = std::nullopt, \ - py::arg("rng_state") = std::nullopt, \ - py::arg("gen") = std::nullopt); +#define MHA_VARLEN_BWD_ASM_PYBIND \ + m.def("fmha_v3_varlen_bwd", \ + &aiter::torch_itfs::fmha_v3_varlen_bwd, \ + py::arg("dout"), \ + py::arg("q"), \ + py::arg("k"), \ + py::arg("v"), \ + py::arg("out"), \ + py::arg("softmax_lse"), \ + py::arg("cu_seqlens_q"), \ + py::arg("cu_seqlens_k"), \ + py::arg("max_seqlen_q"), \ + py::arg("max_seqlen_k"), \ + py::arg("dropout_p"), \ + py::arg("softmax_scale"), \ + py::arg("zero_tensors"), \ + py::arg("is_causal"), \ + py::arg("window_size_left"), \ + py::arg("window_size_right"), \ + py::arg("deterministic"), \ + py::arg("is_v3_atomic_fp32"), \ + py::arg("how_v3_bf16_cvt"), \ + py::arg("dq") = std::nullopt, \ + py::arg("dk") = std::nullopt, \ + py::arg("dv") = std::nullopt, \ + py::arg("alibi_slopes") = std::nullopt, \ + py::arg("rng_state") = std::nullopt, \ + py::arg("gen") = std::nullopt, \ + py::arg("cu_seqlens_q_padded") = std::nullopt, \ + py::arg("cu_seqlens_k_padded") = std::nullopt); #define MHA_BWD_PYBIND \ m.def("mha_bwd", \ @@ -657,32 +659,34 @@ namespace py = pybind11; py::arg("alibi_slopes") = std::nullopt, \ py::arg("gen") = std::nullopt); -#define MHA_VARLEN_BWD_PYBIND \ - m.def("mha_varlen_bwd", \ - &aiter::torch_itfs::mha_varlen_bwd, \ - py::arg("dout"), \ - py::arg("q"), \ - py::arg("k"), \ - py::arg("v"), \ - py::arg("out"), \ - py::arg("softmax_lse"), \ - py::arg("cu_seqlens_q"), \ - py::arg("cu_seqlens_k"), \ - py::arg("max_seqlen_q"), \ - py::arg("max_seqlen_k"), \ - py::arg("dropout_p"), \ - py::arg("softmax_scale"), \ - py::arg("zero_tensors"), \ - py::arg("is_causal"), \ - py::arg("window_size_left"), \ - py::arg("window_size_right"), \ - py::arg("deterministic"), \ - py::arg("dq") = std::nullopt, \ - py::arg("dk") = std::nullopt, \ - py::arg("dv") = std::nullopt, \ - py::arg("alibi_slopes") = std::nullopt, \ - py::arg("rng_state") = std::nullopt, \ - py::arg("gen") = std::nullopt); +#define MHA_VARLEN_BWD_PYBIND \ + m.def("mha_varlen_bwd", \ + &aiter::torch_itfs::mha_varlen_bwd, \ + py::arg("dout"), \ + py::arg("q"), \ + py::arg("k"), \ + py::arg("v"), \ + py::arg("out"), \ + py::arg("softmax_lse"), \ + py::arg("cu_seqlens_q"), \ + py::arg("cu_seqlens_k"), \ + py::arg("max_seqlen_q"), \ + py::arg("max_seqlen_k"), \ + py::arg("dropout_p"), \ + py::arg("softmax_scale"), \ + py::arg("zero_tensors"), \ + py::arg("is_causal"), \ + py::arg("window_size_left"), \ + py::arg("window_size_right"), \ + py::arg("deterministic"), \ + py::arg("dq") = std::nullopt, \ + py::arg("dk") = std::nullopt, \ + py::arg("dv") = std::nullopt, \ + py::arg("alibi_slopes") = std::nullopt, \ + py::arg("rng_state") = std::nullopt, \ + py::arg("gen") = std::nullopt, \ + py::arg("cu_seqlens_q_padded") = std::nullopt, \ + py::arg("cu_seqlens_k_padded") = std::nullopt); #define MOE_CK_2STAGES_PYBIND \ m.def("ck_moe_stage1", \ diff --git 
diff --git a/csrc/include/torch/mha_v3_varlen_bwd.h b/csrc/include/torch/mha_v3_varlen_bwd.h
index 21b85fea92..81afc23f45 100644
--- a/csrc/include/torch/mha_v3_varlen_bwd.h
+++ b/csrc/include/torch/mha_v3_varlen_bwd.h
@@ -14,10 +14,6 @@ fmha_v3_varlen_bwd(const at::Tensor& dout, // [total_q, hq, d_v]
                    const at::Tensor& softmax_lse, // [b, hq, sq]
                    const at::Tensor& cu_seqlens_q, // [b+1]
                    const at::Tensor& cu_seqlens_k, // [b+1]
-                   // FIXME: this two args currently not support on ck side
-                   // and has no host code on aiter side
-                   // const at::Tensor& cu_seqlens_q_padded, // [b+1]
-                   // const at::Tensor& cu_seqlens_k_padded, // [b+1]
                    const int max_seqlen_q,
                    const int max_seqlen_k,
                    const float p_dropout,
@@ -34,7 +30,9 @@ fmha_v3_varlen_bwd(const at::Tensor& dout, // [total_q, hq, d_v]
                    std::optional<at::Tensor> dv_, // [total_k, hk, d_v]
                    std::optional<at::Tensor> alibi_slopes_, // [hq] or [b, hq]
                    std::optional<at::Tensor> rng_state_,
-                   std::optional<at::Generator> gen_);
+                   std::optional<at::Generator> gen_,
+                   std::optional<at::Tensor> cu_seqlens_q_padded = std::nullopt,
+                   std::optional<at::Tensor> cu_seqlens_k_padded = std::nullopt);
 } // namespace torch_itfs
 } // namespace aiter
diff --git a/csrc/include/torch/mha_varlen_bwd.h b/csrc/include/torch/mha_varlen_bwd.h
index ea73564ea3..ac78ec2fb3 100644
--- a/csrc/include/torch/mha_varlen_bwd.h
+++ b/csrc/include/torch/mha_varlen_bwd.h
@@ -1,6 +1,6 @@
 #pragma once
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
 #include

 namespace aiter {
@@ -28,6 +28,9 @@ mha_varlen_bwd(const at::Tensor& dout, // [total_q, hq, d]
                std::optional<at::Tensor> dv, // [total_k, hk, d]
                std::optional<at::Tensor> alibi_slopes, // [hq] or [b, hq]
                std::optional<at::Tensor> rng_state,
-               std::optional<at::Generator> gen);
+               std::optional<at::Generator> gen,
+               std::optional<at::Tensor> cu_seqlens_q_padded, // [b+1]
+               std::optional<at::Tensor> cu_seqlens_k_padded  // [b+1]
+);
 } // namespace torch_itfs
 } // namespace aiter
diff --git a/csrc/py_itfs_ck/mha_bwd_kernels.cu b/csrc/py_itfs_ck/mha_bwd_kernels.cu
index b0c1420cc0..2b20f11788 100644
--- a/csrc/py_itfs_ck/mha_bwd_kernels.cu
+++ b/csrc/py_itfs_ck/mha_bwd_kernels.cu
@@ -145,9 +145,12 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
                         dv.data_ptr(),
                         dbias_ptr,
                         dq_acc.data_ptr(), // dq_acc
-                        nullptr, // seqstart_q
-                        nullptr, // seqstart_k
+                        nullptr, // seqstart_q_ptr
+                        nullptr, // seqstart_k_ptr
+                        nullptr, // seqlen_q_ptr
                         nullptr, // seqlen_k_ptr
+                        nullptr, // cu_seqlen_q_ptr
+                        nullptr, // cu_seqlen_k_ptr
                         seqlen_q,
                         seqlen_k,
                         b,
@@ -155,7 +158,7 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
                         seqlen_k, // max_seqlen_k
                         hdim_q, // hdim_q
                         hdim_v, // hdim_v
-                        h, // nhead
+                        h, // nhead_q
                         h_k, // nhead_k
                         softmax_scale,
                         stride_q,
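The CK batch-mode hunk above widens fmha_bwd_args from three sequence pointers to six, and in batch mode every slot stays null. As a reading aid for the new slot order, here is a small illustrative Python mirror; the dataclass and its field comments are mine (derived from the comments in the hunk), not part of the patch:

    from dataclasses import dataclass
    from typing import Optional, Sequence

    @dataclass
    class SeqDescriptors:
        seqstart_q: Optional[Sequence[int]]   # [b+1] physical (padded) start offsets
        seqstart_k: Optional[Sequence[int]]   # [b+1] physical (padded) start offsets
        seqlen_q: Optional[Sequence[int]]     # [b] per-sequence logical lengths
        seqlen_k: Optional[Sequence[int]]     # [b] per-sequence logical lengths
        cu_seqlen_q: Optional[Sequence[int]]  # [b+1] logical cumulative lengths
        cu_seqlen_k: Optional[Sequence[int]]  # [b+1] logical cumulative lengths

    # Batch mode (fixed-shape tensors, as in get_ck_fmha_bwd_args): all six are null.
    batch_mode_args = SeqDescriptors(None, None, None, None, None, None)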
diff --git a/csrc/py_itfs_ck/mha_fwd_kernels.cu b/csrc/py_itfs_ck/mha_fwd_kernels.cu
index 1bdfde270b..d53678360d 100644
--- a/csrc/py_itfs_ck/mha_fwd_kernels.cu
+++ b/csrc/py_itfs_ck/mha_fwd_kernels.cu
@@ -97,20 +97,19 @@ mha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
                         has_dropout_randval ? dropout_randval.data_ptr() : nullptr,
                         has_lse ? softmax_lse.data_ptr() : nullptr,
                         out.data_ptr(),
-                        cu_seqlen_q_ptr,
-                        cu_seqlen_kv_ptr,
-                        nullptr, // seqstart_q
-                        nullptr, // seqstart_k
-                        nullptr,
-                        nullptr, // seqstart_padded_q_ptr
-                        nullptr, // seqstart_padded_k_ptr
+                        nullptr, // seqstart_q_ptr
+                        nullptr, // seqstart_k_ptr
+                        nullptr, // seqlen_q_ptr
+                        nullptr, // seqlen_k_ptr
+                        cu_seqlen_q_ptr, // cu_seqlen_q_ptr
+                        cu_seqlen_kv_ptr, // cu_seqlen_k_ptr
                         seqlen_q,
                         seqlen_k,
                         b,
                         seqlen_q, // max_seqlen_q
                         d, // hdim_q
                         d_v, // hdim_v
-                        h, // nhead
+                        h, // nhead_q
                         h_k, // nhead_k
                         softmax_scale, // scale_s
                         1, // scale_p
@@ -139,7 +138,7 @@ mha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
                         mask.left,
                         mask.right,
                         static_cast<ck_tile::index_t>(mask.type),
-                        0,
+                        0, // min_seqlen_q
                         p_dropout,
                         has_dropout_randval,
                         drop_seed_offset};
&dout, // [total_q, hq, d_v] TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype"); TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32"); TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32"); - + if (cu_seqlens_q_padded.has_value()) { + TORCH_CHECK(cu_seqlens_q_padded.value().dtype() == torch::kInt32, "cu_seqlens_q_padded must have dtype int32"); + CHECK_CONTIGUOUS(cu_seqlens_q_padded.value()); + } + if (cu_seqlens_k_padded.has_value()) { + TORCH_CHECK(cu_seqlens_k_padded.value().dtype() == torch::kInt32, "cu_seqlens_k_padded must have dtype int32"); + CHECK_CONTIGUOUS(cu_seqlens_k_padded.value()); + } std::string q_dtype_str = q_dtype == torch::kFloat16 ? "fp16" : "bf16"; CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); @@ -314,7 +348,7 @@ mha_varlen_bwd(const at::Tensor &dout, // [total_q, hq, d_v] const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard{q.device()}; auto opts = q.options(); - auto softmax_d = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); + auto softmax_d = torch::empty({batch_size, num_heads, total_q}, opts.dtype(at::kFloat)); at::Tensor dq_accum; if (!deterministic) { @@ -383,6 +417,8 @@ mha_varlen_bwd(const at::Tensor &dout, // [total_q, hq, d_v] v, cu_seqlens_q, cu_seqlens_k, + cu_seqlens_q_padded, + cu_seqlens_k_padded, alibi_slopes_, out, softmax_lse, diff --git a/csrc/py_itfs_ck/mha_varlen_fwd_kernels.cu b/csrc/py_itfs_ck/mha_varlen_fwd_kernels.cu index 7d386cdbf2..54ca060dc9 100644 --- a/csrc/py_itfs_ck/mha_varlen_fwd_kernels.cu +++ b/csrc/py_itfs_ck/mha_varlen_fwd_kernels.cu @@ -96,31 +96,25 @@ mha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse, bias_ptr = alibi_slopes.data_ptr(); stride_bias = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0; } - - // Validate padded seqstart arrays if provided: shape [b+1], 1D, contiguous, int32/int64, monotonic - auto validate_and_maybe_convert = [&](std::optional &opt_seqstarts, - const char *name) -> const ck_tile::index_t* { - if (!opt_seqstarts.has_value()) return nullptr; - const at::Tensor &t = opt_seqstarts.value(); - CHECK_DEVICE(t); - TORCH_CHECK(t.dim() == 1, name, " must be 1D"); - TORCH_CHECK(t.numel() == b + 1, name, " must have length batch+1"); - TORCH_CHECK(t.is_contiguous(), name, " must be contiguous"); - TORCH_CHECK(t.dtype() == torch::kInt32, name, " must be int32, actual: ", t.dtype()); - auto ptr = reinterpret_cast(t.data_ptr()); - auto acc = t.index({0}).item(); - TORCH_CHECK(acc == 0, name, " first element must be 0"); - auto data_ptr32 = t.data_ptr(); - for (int i = 1; i < t.numel(); ++i) { - int v = data_ptr32[i]; - TORCH_CHECK(v >= acc, name, " must be non-decreasing"); - acc = v; - } - return ptr; - }; + + const void* seqstart_k_ptr = nullptr; + const void* seqstart_q_ptr = nullptr; + const void* cu_seqlen_k_ptr = nullptr; + const void* cu_seqlen_q_ptr = nullptr; + + if (cu_seqlens_k_padded_.has_value()) { + seqstart_k_ptr = cu_seqlens_k_padded_.value().data_ptr(); + cu_seqlen_k_ptr = cu_seqlens_k.has_value() ? cu_seqlens_k.value().data_ptr() : nullptr; + } else { + seqstart_k_ptr = cu_seqlens_k.has_value() ? 
diff --git a/csrc/py_itfs_ck/mha_varlen_fwd_kernels.cu b/csrc/py_itfs_ck/mha_varlen_fwd_kernels.cu
index 7d386cdbf2..54ca060dc9 100644
--- a/csrc/py_itfs_ck/mha_varlen_fwd_kernels.cu
+++ b/csrc/py_itfs_ck/mha_varlen_fwd_kernels.cu
@@ -96,31 +96,25 @@ mha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
         bias_ptr = alibi_slopes.data_ptr();
         stride_bias = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
     }
-
-    // Validate padded seqstart arrays if provided: shape [b+1], 1D, contiguous, int32/int64, monotonic
-    auto validate_and_maybe_convert = [&](std::optional<at::Tensor> &opt_seqstarts,
-                                          const char *name) -> const ck_tile::index_t* {
-        if (!opt_seqstarts.has_value()) return nullptr;
-        const at::Tensor &t = opt_seqstarts.value();
-        CHECK_DEVICE(t);
-        TORCH_CHECK(t.dim() == 1, name, " must be 1D");
-        TORCH_CHECK(t.numel() == b + 1, name, " must have length batch+1");
-        TORCH_CHECK(t.is_contiguous(), name, " must be contiguous");
-        TORCH_CHECK(t.dtype() == torch::kInt32, name, " must be int32, actual: ", t.dtype());
-        auto ptr = reinterpret_cast<const ck_tile::index_t*>(t.data_ptr());
-        auto acc = t.index({0}).item<int>();
-        TORCH_CHECK(acc == 0, name, " first element must be 0");
-        auto data_ptr32 = t.data_ptr<int32_t>();
-        for (int i = 1; i < t.numel(); ++i) {
-            int v = data_ptr32[i];
-            TORCH_CHECK(v >= acc, name, " must be non-decreasing");
-            acc = v;
-        }
-        return ptr;
-    };
+
+    const void* seqstart_k_ptr = nullptr;
+    const void* seqstart_q_ptr = nullptr;
+    const void* cu_seqlen_k_ptr = nullptr;
+    const void* cu_seqlen_q_ptr = nullptr;
+
+    if (cu_seqlens_k_padded_.has_value()) {
+        seqstart_k_ptr = cu_seqlens_k_padded_.value().data_ptr();
+        cu_seqlen_k_ptr = cu_seqlens_k.has_value() ? cu_seqlens_k.value().data_ptr() : nullptr;
+    } else {
+        seqstart_k_ptr = cu_seqlens_k.has_value() ? cu_seqlens_k.value().data_ptr() : nullptr;
+    }

-    const ck_tile::index_t *seqstart_padded_q_ptr = validate_and_maybe_convert(cu_seqlens_q_padded_, "cu_seqlens_q_padded");
-    const ck_tile::index_t *seqstart_padded_k_ptr = validate_and_maybe_convert(cu_seqlens_k_padded_, "cu_seqlens_k_padded");
+    if (cu_seqlens_q_padded_.has_value()) {
+        seqstart_q_ptr = cu_seqlens_q_padded_.value().data_ptr();
+        cu_seqlen_q_ptr = cu_seqlens_q.data_ptr();
+    } else {
+        seqstart_q_ptr = cu_seqlens_q.data_ptr();
+    }

     return mha_fwd_args{q.data_ptr(),
                         k.data_ptr(),
@@ -129,20 +123,19 @@ mha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
                         has_dropout_randval ? dropout_randval.data_ptr() : nullptr,
                         has_lse ? softmax_lse.data_ptr() : nullptr,
                         out.data_ptr(),
-                        nullptr, // cu_seqlen_q_ptr (batch mode only)
-                        nullptr, // cu_seqlen_kv_ptr (batch mode only)
-                        cu_seqlens_q.data_ptr(), // seqstart_q
-                        cu_seqlens_k.has_value() ? cu_seqlens_k.value().data_ptr() : nullptr, // seqstart_k
-                        seqlens_k.has_value() ? seqlens_k.value().data_ptr() : nullptr, // seqlen_kpads
-                        seqstart_padded_q_ptr,
-                        seqstart_padded_k_ptr,
+                        seqstart_q_ptr, // seqstart_q_ptr (cumulative physical with padding)
+                        seqstart_k_ptr, // seqstart_k_ptr (cumulative physical with padding)
+                        nullptr, // seqlen_q_ptr (per-sequence logical, alternative to cu_seqlen_q_ptr)
+                        seqlens_k.has_value() ? seqlens_k.value().data_ptr() : nullptr, // seqlen_k_ptr (per-sequence logical K lengths)
+                        cu_seqlen_q_ptr, // cu_seqlen_q_ptr
+                        cu_seqlen_k_ptr, // cu_seqlen_k_ptr
                         total_q,
                         total_k,
                         b,
                         max_seqlen_q,
                         d, // hdim_q
                         d_v, // hdim_v
-                        h, // nhead
+                        h, // nhead_q
                         h_k, // nhead_k
                         softmax_scale, // scale_s
                         1, // scale_p
@@ -624,9 +617,7 @@ mha_varlen_fwd(
                          bias_type,
                          has_lse,
                          false, // use_ext_asm
-                         1, // how_v3_bf16_cvt
-                         args.seqstart_padded_q_ptr,
-                         args.seqstart_padded_k_ptr);
+                         1); // how_v3_bf16_cvt
     TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd");
     }
 }
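Deleting validate_and_maybe_convert means the forward host path no longer verifies that a padded array is 1-D, of length [b+1], contiguous, int32, zero-based, and non-decreasing. Callers who want those guards back can run an equivalent check in Python before dispatch; a minimal sketch mirroring the removed lambda (function name hypothetical):

    import torch

    def check_cu_seqlens_padded(t: torch.Tensor, batch: int, name: str) -> None:
        # Same invariants the removed C++ lambda enforced.
        assert t.dim() == 1 and t.numel() == batch + 1, f"{name} must have shape [b+1]"
        assert t.is_contiguous(), f"{name} must be contiguous"
        assert t.dtype == torch.int32, f"{name} must be int32"
        cpu = t.cpu()
        assert cpu[0].item() == 0, f"{name} first element must be 0"
        assert bool((cpu[1:] >= cpu[:-1]).all()), f"{name} must be non-decreasing"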
diff --git a/csrc/py_itfs_cu/asm_mha_bwd.cu b/csrc/py_itfs_cu/asm_mha_bwd.cu
index 730f49d511..037394791e 100644
--- a/csrc/py_itfs_cu/asm_mha_bwd.cu
+++ b/csrc/py_itfs_cu/asm_mha_bwd.cu
@@ -119,9 +119,12 @@ fmha_bwd_args get_asm_fmha_bwd_args(const mask_info &mask,
                         dv.data_ptr(),
                         nullptr, // dbias
                         dq_acc.data_ptr(), // dq_acc
-                        nullptr, // seqstart_q
-                        nullptr, // seqstart_k
-                        nullptr, // seqlen_k_ptr
+                        nullptr, // seqstart_q_ptr (batch mode)
+                        nullptr, // seqstart_k_ptr (batch mode)
+                        nullptr, // seqlen_q_ptr (batch mode)
+                        nullptr, // seqlen_k_ptr (batch mode)
+                        nullptr, // cu_seqlen_q_ptr (batch mode)
+                        nullptr, // cu_seqlen_k_ptr (batch mode)
                         seqlen_q,
                         seqlen_k,
                         b,
@@ -129,7 +132,7 @@ fmha_bwd_args get_asm_fmha_bwd_args(const mask_info &mask,
                         seqlen_k, // max_seqlen_k
                         hdim_q, // hdim_q
                         hdim_v, // hdim_v
-                        h, // nhead
+                        h, // nhead_q
                         h_k, // nhead_k
                         softmax_scale,
                         stride_q,
diff --git a/csrc/py_itfs_cu/asm_mha_fwd.cu b/csrc/py_itfs_cu/asm_mha_fwd.cu
index 62354ded84..f5cbf24d86 100644
--- a/csrc/py_itfs_cu/asm_mha_fwd.cu
+++ b/csrc/py_itfs_cu/asm_mha_fwd.cu
@@ -92,20 +92,19 @@ mha_fwd_args get_asm_fmha_fwd_args(bool has_lse,
                        has_dropout_randval ? dropout_randval.data_ptr() : nullptr,
                        has_lse ? softmax_lse.data_ptr() : nullptr,
                        out.data_ptr(),
-                       nullptr, // cu_seqlen_q_ptr
-                       nullptr, // cu_seqlen_kv_ptr
-                       nullptr, // seqstart_q
-                       nullptr, // seqstart_k
+                       nullptr, // seqstart_q_ptr
+                       nullptr, // seqstart_k_ptr
+                       nullptr, // seqlen_q_ptr
                        nullptr, // seqlen_k_ptr
-                       nullptr, // seqstart_padded_q_ptr
-                       nullptr, // seqstart_padded_k_ptr
+                       nullptr, // cu_seqlen_q_ptr
+                       nullptr, // cu_seqlen_k_ptr
                        seqlen_q,
                        seqlen_k,
                        b,
                        seqlen_q, // max_seqlen_q
                        d, // hdim_q
                        d_v, // hdim_v
-                       h, // nhead
+                       h, // nhead_q
                        h_k, // nhead_k
                        softmax_scale, // scale_s
                        1, // scale_p
@@ -134,7 +133,7 @@ mha_fwd_args get_asm_fmha_fwd_args(bool has_lse,
                        mask.left,
                        mask.right,
                        static_cast<ck_tile::index_t>(mask.type),
-                       0,
+                       0, // min_seqlen_q
                        p_dropout,
                        has_dropout_randval,
                        drop_seed_offset};
&dout, // [total_q, hq, d_v TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype"); TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32"); TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32"); - + if (cu_seqlens_q_padded.has_value()) { + TORCH_CHECK(cu_seqlens_q_padded.value().dtype() == torch::kInt32, "cu_seqlens_q_padded must have dtype int32"); + CHECK_CONTIGUOUS(cu_seqlens_q_padded.value()); + } + if (cu_seqlens_k_padded.has_value()) { + TORCH_CHECK(cu_seqlens_k_padded.value().dtype() == torch::kInt32, "cu_seqlens_k_padded must have dtype int32"); + CHECK_CONTIGUOUS(cu_seqlens_k_padded.value()); + } std::string q_dtype_str = q_dtype == torch::kFloat16 ? "fp16" : "bf16"; CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); @@ -333,7 +362,7 @@ fmha_v3_varlen_bwd(const at::Tensor &dout, // [total_q, hq, d_v const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard{q.device()}; auto opts = q.options(); - auto softmax_d = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); + auto softmax_d = torch::empty({batch_size, num_heads, total_q}, opts.dtype(at::kFloat)); at::Tensor dq_accum; if (!deterministic) { @@ -404,6 +433,8 @@ fmha_v3_varlen_bwd(const at::Tensor &dout, // [total_q, hq, d_v v, cu_seqlens_q, cu_seqlens_k, + cu_seqlens_q_padded, + cu_seqlens_k_padded, alibi_slopes_, out, softmax_lse, @@ -418,6 +449,7 @@ fmha_v3_varlen_bwd(const at::Tensor &dout, // [total_q, hq, d_v drop_seed_offset, is_v3_atomic_fp32); + float t = aiter::mha_bwd(args, stream_config, q_dtype_str, diff --git a/csrc/py_itfs_cu/asm_mha_varlen_fwd.cu b/csrc/py_itfs_cu/asm_mha_varlen_fwd.cu index 07cbf20a08..faa1a662cb 100644 --- a/csrc/py_itfs_cu/asm_mha_varlen_fwd.cu +++ b/csrc/py_itfs_cu/asm_mha_varlen_fwd.cu @@ -101,20 +101,19 @@ mha_fwd_args get_asm_mha_varlen_fwd_args(bool has_lse, has_dropout_randval ? dropout_randval.data_ptr() : nullptr, has_lse ? softmax_lse.data_ptr() : nullptr, out.data_ptr(), - nullptr, - nullptr, - cu_seqlens_q.data_ptr(), // seqstart_q - cu_seqlens_k.has_value() ? cu_seqlens_k.value().data_ptr() : nullptr, // seqstart_k - seqlens_k.has_value() ? seqlens_k.value().data_ptr() : nullptr, // seqlen_kpads - nullptr, - nullptr, + cu_seqlens_q.data_ptr(), // seqstart_q_ptr (cumulative physical) + cu_seqlens_k.has_value() ? cu_seqlens_k.value().data_ptr() : nullptr, // seqstart_k_ptr + nullptr, // seqlen_q_ptr (per-sequence logical, not used here) + seqlens_k.has_value() ? seqlens_k.value().data_ptr() : nullptr, // seqlen_k_ptr + nullptr, // cu_seqlen_q_ptr (not used in this mode) + nullptr, // cu_seqlen_k_ptr (not used in this mode) total_q, total_k, b, max_seqlen_q, d, // hdim_q d_v, // hdim_v - h, // nhead + h, // nhead_q h_k, // nhead_k softmax_scale, // scale_s 1, // scale_p diff --git a/hsa/gfx942/fmha_v3_bwd/codegen.py b/hsa/gfx942/fmha_v3_bwd/codegen.py index 3749ce8558..b420509976 100644 --- a/hsa/gfx942/fmha_v3_bwd/codegen.py +++ b/hsa/gfx942/fmha_v3_bwd/codegen.py @@ -809,14 +809,23 @@ class fmha_bwd_v3_kernel args.ptr_do = a.do_ptr; args.ptr_lse = a.lse_ptr; args.ptr_d = a.d_ptr; - args.ptr_qseq = a.seqstart_q_ptr; - args.ptr_kseq = a.seqstart_k_ptr; - args.ptr_qseq_padded = seqlen_q_padded == nullptr - ? a.seqstart_q_ptr - : seqlen_q_padded; - args.ptr_kseq_padded = seqlen_k_padded == nullptr - ? 
diff --git a/hsa/gfx942/fmha_v3_bwd/codegen.py b/hsa/gfx942/fmha_v3_bwd/codegen.py
index 3749ce8558..b420509976 100644
--- a/hsa/gfx942/fmha_v3_bwd/codegen.py
+++ b/hsa/gfx942/fmha_v3_bwd/codegen.py
@@ -809,14 +809,23 @@ class fmha_bwd_v3_kernel
         args.ptr_do = a.do_ptr;
         args.ptr_lse = a.lse_ptr;
         args.ptr_d = a.d_ptr;
-        args.ptr_qseq = a.seqstart_q_ptr;
-        args.ptr_kseq = a.seqstart_k_ptr;
-        args.ptr_qseq_padded = seqlen_q_padded == nullptr
-                                   ? a.seqstart_q_ptr
-                                   : seqlen_q_padded;
-        args.ptr_kseq_padded = seqlen_k_padded == nullptr
-                                   ? a.seqstart_k_ptr
-                                   : seqlen_k_padded;
+
+        if (a.cu_seqlen_k_ptr && a.seqstart_k_ptr) {
+            args.ptr_kseq_padded = a.seqstart_k_ptr;
+            args.ptr_kseq = a.cu_seqlen_k_ptr;
+        } else {
+            args.ptr_kseq = a.seqstart_k_ptr;
+            args.ptr_kseq_padded = a.seqstart_k_ptr;
+        }
+
+        if (a.cu_seqlen_q_ptr && a.seqstart_q_ptr) {
+            args.ptr_qseq_padded = a.seqstart_q_ptr;
+            args.ptr_qseq = a.cu_seqlen_q_ptr;
+        } else {
+            args.ptr_qseq = a.seqstart_q_ptr;
+            args.ptr_qseq_padded = a.seqstart_q_ptr;
+        }
+
         args.scalar = a.scale;
         args.log2e = ck_tile::log2e_v<float>;
         args.ratio = a.nhead_q / a.nhead_k;
diff --git a/hsa/gfx950/fmha_v3_bwd/codegen.py b/hsa/gfx950/fmha_v3_bwd/codegen.py
index 7c51f1b0e6..17fb23267d 100644
--- a/hsa/gfx950/fmha_v3_bwd/codegen.py
+++ b/hsa/gfx950/fmha_v3_bwd/codegen.py
@@ -947,14 +947,23 @@ class fmha_bwd_v3_kernel
         args.ptr_do = a.do_ptr;
         args.ptr_lse = a.lse_ptr;
         args.ptr_d = a.d_ptr;
-        args.ptr_qseq = a.seqstart_q_ptr;
-        args.ptr_kseq = a.seqstart_k_ptr;
-        args.ptr_qseq_padded = seqlen_q_padded == nullptr
-                                   ? a.seqstart_q_ptr
-                                   : seqlen_q_padded;
-        args.ptr_kseq_padded = seqlen_k_padded == nullptr
-                                   ? a.seqstart_k_ptr
-                                   : seqlen_k_padded;
+
+        if (a.cu_seqlen_k_ptr && a.seqstart_k_ptr) {
+            args.ptr_kseq_padded = a.seqstart_k_ptr;
+            args.ptr_kseq = a.cu_seqlen_k_ptr;
+        } else {
+            args.ptr_kseq = a.seqstart_k_ptr;
+            args.ptr_kseq_padded = a.seqstart_k_ptr;
+        }
+
+        if (a.cu_seqlen_q_ptr && a.seqstart_q_ptr) {
+            args.ptr_qseq_padded = a.seqstart_q_ptr;
+            args.ptr_qseq = a.cu_seqlen_q_ptr;
+        } else {
+            args.ptr_qseq = a.seqstart_q_ptr;
+            args.ptr_qseq_padded = a.seqstart_q_ptr;
+        }
+
         args.scalar = a.scale;
         args.log2e = ck_tile::log2e_v<float>;
         args.ratio = a.nhead_q / a.nhead_k;
@@ -1104,14 +1113,22 @@ class fmha_bwd_v3_kernel
         args.BAs_dv = a.batch_stride_dv * 2;
         args.Seqs_dv = a.stride_dv * 2;
         args.Hs_lsed = a.nhead_stride_lsed * 4;
-        args.ptr_qseq = a.seqstart_q_ptr;
-        args.ptr_kseq = a.seqstart_k_ptr;
-        args.ptr_qseq_padded = seqlen_q_padded == nullptr
-                                   ? a.seqstart_q_ptr
-                                   : seqlen_q_padded;
-        args.ptr_kseq_padded = seqlen_k_padded == nullptr
-                                   ? a.seqstart_k_ptr
-                                   : seqlen_k_padded;
+
+        if (a.cu_seqlen_k_ptr && a.seqstart_k_ptr) {
+            args.ptr_kseq_padded = a.seqstart_k_ptr;
+            args.ptr_kseq = a.cu_seqlen_k_ptr;
+        } else {
+            args.ptr_kseq = a.seqstart_k_ptr;
+            args.ptr_kseq_padded = a.seqstart_k_ptr;
+        }
+
+        if (a.cu_seqlen_q_ptr && a.seqstart_q_ptr) {
+            args.ptr_qseq_padded = a.seqstart_q_ptr;
+            args.ptr_qseq = a.cu_seqlen_q_ptr;
+        } else {
+            args.ptr_qseq = a.seqstart_q_ptr;
+            args.ptr_qseq_padded = a.seqstart_q_ptr;
+        }
         args.max_seqlen_dq = a.max_seqlen_q;

         auto traits = fmha_bwd_v3_traits{a.batch,
@@ -1178,14 +1195,22 @@ class fmha_bwd_v3_kernel
         args.BAs_dv = a.batch_stride_dv * 2;
         args.Seqs_dv = a.stride_dv * 2;
         args.Hs_lsed = a.nhead_stride_lsed * 4;
-        args.ptr_qseq = a.seqstart_q_ptr;
-        args.ptr_kseq = a.seqstart_k_ptr;
-        args.ptr_qseq_padded = seqlen_q_padded == nullptr
-                                   ? a.seqstart_q_ptr
-                                   : seqlen_q_padded;
-        args.ptr_kseq_padded = seqlen_k_padded == nullptr
-                                   ? a.seqstart_k_ptr
-                                   : seqlen_k_padded;
+
+        if (a.cu_seqlen_k_ptr && a.seqstart_k_ptr) {
+            args.ptr_kseq_padded = a.seqstart_k_ptr;
+            args.ptr_kseq = a.cu_seqlen_k_ptr;
+        } else {
+            args.ptr_kseq = a.seqstart_k_ptr;
+            args.ptr_kseq_padded = a.seqstart_k_ptr;
+        }
+
+        if (a.cu_seqlen_q_ptr && a.seqstart_q_ptr) {
+            args.ptr_qseq_padded = a.seqstart_q_ptr;
+            args.ptr_qseq = a.cu_seqlen_q_ptr;
+        } else {
+            args.ptr_qseq = a.seqstart_q_ptr;
+            args.ptr_qseq_padded = a.seqstart_q_ptr;
+        }
         args.max_seqlen_dq = (a.max_seqlen_q + 15) / 16 * 16;

         fmha_bwd_dq_shuffle_args dq_shuffule_args;
@@ -1200,10 +1225,15 @@ class fmha_bwd_v3_kernel
         dq_shuffule_args.Seqs_dq = a.stride_dq * 2;
         dq_shuffule_args.seqlen_q = a.seqlen_q;
         dq_shuffule_args.head_dim = a.hdim_q;
-        dq_shuffule_args.ptr_qseq = a.seqstart_q_ptr;
-        dq_shuffule_args.ptr_qseq_padded = seqlen_q_padded == nullptr
-                                               ? a.seqstart_q_ptr
-                                               : seqlen_q_padded;
+
+        if (a.cu_seqlen_q_ptr && a.seqstart_q_ptr) {
+            dq_shuffule_args.ptr_qseq_padded = a.seqstart_q_ptr;
+            dq_shuffule_args.ptr_qseq = a.cu_seqlen_q_ptr;
+        } else {
+            dq_shuffule_args.ptr_qseq = a.seqstart_q_ptr;
+            dq_shuffule_args.ptr_qseq_padded = a.seqstart_q_ptr;
+        }
+
         dq_shuffule_args.max_seqlen_dq = (a.max_seqlen_q + 15) / 16 * 16;

         auto traits = fmha_bwd_v3_traits{a.batch,
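Two details of the regenerated kernel-argument code are worth calling out: when no separate logical array arrives, ptr_*seq and ptr_*seq_padded alias the same cumulative buffer, so the assembly kernel can read both unconditionally; and the dq-shuffle workspace rounds max_seqlen_q up to a multiple of 16. A hedged Python restatement of both rules (helper names are mine):

    def pick_seq_ptrs(seqstart, cu_seqlen):
        # Returns (ptr_qseq, ptr_qseq_padded), mirroring the generated C++ above.
        if cu_seqlen is not None and seqstart is not None:
            return cu_seqlen, seqstart  # logical lengths + physical offsets
        return seqstart, seqstart       # both alias the one cumulative array

    def max_seqlen_dq(max_seqlen_q: int) -> int:
        # (a.max_seqlen_q + 15) / 16 * 16 in the generated code: round up to 16.
        return (max_seqlen_q + 15) // 16 * 16

    assert max_seqlen_dq(100) == 112
    assert pick_seq_ptrs([0, 128, 256], None) == ([0, 128, 256], [0, 128, 256])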
diff --git a/op_tests/cpp/mha/benchmark_mha_bwd.cpp b/op_tests/cpp/mha/benchmark_mha_bwd.cpp
index e4b5e5413a..aaee36f7e0 100644
--- a/op_tests/cpp/mha/benchmark_mha_bwd.cpp
+++ b/op_tests/cpp/mha/benchmark_mha_bwd.cpp
@@ -395,8 +395,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
                         : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */);
     ck_tile::HostTensor dq_acc_host(
         std::array{nsplits, shape_batch, nhead, shape_seqlen_q, hdim_q});
-    ck_tile::HostTensor dq_acc_host_a16(std::array{
-        nsplits, batch, nhead, a16_dq_acc_seq, a16_dq_acc_hdim});
+    ck_tile::HostTensor dq_acc_host_a16(
+        std::array{nsplits, batch, nhead, a16_dq_acc_seq, a16_dq_acc_hdim});

     if(init_method == 0)
     {
@@ -579,6 +579,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
                          seqstart_q.GetDeviceBuffer(),
                          seqstart_k.GetDeviceBuffer(),
                          nullptr,
+                         nullptr,
+                         nullptr,
+                         nullptr,
                          shape_seqlen_q,
                          shape_seqlen_k,
                          batch,
diff --git a/op_tests/cpp/mha/benchmark_mha_fwd.cpp b/op_tests/cpp/mha/benchmark_mha_fwd.cpp
index 9b5d2cb2f9..afcc632f79 100644
--- a/op_tests/cpp/mha/benchmark_mha_fwd.cpp
+++ b/op_tests/cpp/mha/benchmark_mha_fwd.cpp
@@ -16,7 +16,6 @@
 #include
 #include
-
 auto create_args(int argc, char* argv[])
 {
     ck_tile::ArgParser arg_parser;
@@ -365,15 +364,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
 #endif
     const bool use_kvcache = (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size);

-    auto [seqlen_qs, seqlen_ks, seqlen_kpads] =
+    auto [seqlen_qs, seqlen_ks, seqlen_qpads, seqlen_kpads] =
         generate_missing_seqlens(mode,
                                  batch,
                                  arg_parser.get_int_vec("s"),
                                  arg_parser.get_int_vec("s_k"),
+                                 {}, // q_pad_val
                                  arg_parser.get_int_vec("s_kpad"),
                                  /*seqlen_k_min=*/0 < seqlen_knew ? seqlen_knew : 0,
                                  need_append_kvcache,
                                  random_engine);
+    ck_tile::ignore = seqlen_qpads;
+
     // compute kvcache seqlen_k (before appending knew/vnew)
     auto cache_seqlen_ks = seqlen_ks;
     std::transform(cache_seqlen_ks.begin(),
@@ -657,7 +659,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
         ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, next_seed()}(knew_host);
         ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, next_seed()}(v_host);
         ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, next_seed()}(vnew_host);
-        ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, next_seed()}(bias_host);
+        ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, next_seed()}(
+            bias_host);
     }
     else if(init_method == "ni")
     {
@@ -666,7 +669,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
         ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, next_seed()}(knew_host);
         ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, next_seed()}(v_host);
         ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, next_seed()}(vnew_host);
-        ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, next_seed()}(bias_host);
+        ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, next_seed()}(
+            bias_host);
     }
     else if(init_method == "uf" || init_method == "1")
     {
@@ -700,14 +704,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
     {
         ck_tile::FillUniformDistribution{-q_dtype_max, q_dtype_max, next_seed()}(q_host);
         ck_tile::FillUniformDistribution{-k_dtype_max, k_dtype_max, next_seed()}(k_host);
-        ck_tile::FillUniformDistribution{-k_dtype_max, k_dtype_max, next_seed()}(knew_host);
+        ck_tile::FillUniformDistribution{-k_dtype_max, k_dtype_max, next_seed()}(
+            knew_host);
         ck_tile::FillUniformDistribution{-v_dtype_max, v_dtype_max, next_seed()}(v_host);
-        ck_tile::FillUniformDistribution{-v_dtype_max, v_dtype_max, next_seed()}(vnew_host);
+        ck_tile::FillUniformDistribution{-v_dtype_max, v_dtype_max, next_seed()}(
+            vnew_host);

         // bias_fp8 = qscale_bias * bias_fp32
         float qscale_bias = (q_dtype_max / range_q) * (k_dtype_max / range_k);
         // Assume bias is in [-1.f, 1.f] in original fp32
-        ck_tile::FillUniformDistribution{-qscale_bias, qscale_bias, next_seed()}(bias_host);
+        ck_tile::FillUniformDistribution{-qscale_bias, qscale_bias, next_seed()}(
+            bias_host);
     }

     if(bias.type == bias_enum::alibi)
     {
diff --git a/op_tests/test_mha_varlen.py b/op_tests/test_mha_varlen.py
index 4ff898ffca..fa2cc63058 100644
--- a/op_tests/test_mha_varlen.py
+++ b/op_tests/test_mha_varlen.py
@@ -259,6 +259,7 @@ def run_ck_seq_padding(
     causal=False,
     window_size=(-1, -1),
     alibi_slopes=None,
+    dout=None,
 ):
     """Run CK varlen forward with physically padded inputs."""

@@ -298,9 +299,9 @@ def run_ck_seq_padding(
             pieces.append(tensor[i, : padded_lens[i]])
         return torch.cat(pieces, dim=0)

-    q_flat = _flatten(q, q_padded_lens)
-    k_flat = _flatten(k, k_padded_lens)
-    v_flat = _flatten(v, k_padded_lens)
+    q_flat = _flatten(q, q_padded_lens).requires_grad_(True)
+    k_flat = _flatten(k, k_padded_lens).requires_grad_(True)
+    v_flat = _flatten(v, k_padded_lens).requires_grad_(True)

     outputs = aiter.flash_attn_varlen_func(
         q_flat,
@@ -315,7 +316,7 @@ def run_ck_seq_padding(
         window_size=window_size,
         alibi_slopes=alibi_slopes,
         deterministic=deterministic,
-        return_lse=False,
+        return_lse=True,
         return_attn_probs=False,
         cu_seqlens_q_padded=cu_seqlens_q_padded,
         cu_seqlens_k_padded=cu_seqlens_k_padded,
     )
@@ -332,7 +333,44 @@ def run_ck_seq_padding(
         out_batch[:keep] = out_flat[start : start + keep]
         out_batches.append(out_batch)

-    return torch.stack(out_batches, dim=0)
+    out_stack = torch.stack(out_batches, dim=0)
+
+    if dout is None:
+        return out_stack
+
+    dout_flat = _flatten(dout, q_padded_lens)
+
+    dq_flat, dk_flat, dv_flat = torch.autograd.grad(
+        outputs=out_flat,
+        inputs=(q_flat, k_flat, v_flat),
+        grad_outputs=dout_flat,
+        create_graph=False,
+        retain_graph=True,
+        allow_unused=True,
+    )
+
+    def _unflatten(flat, padded_lens, max_padded_len, head_dim, value_dim):
+        pieces = []
+        start = 0
+        for i in range(batch_size):
+            end = start + padded_lens[i]
+            t = torch.zeros(
+                max_padded_len,
+                head_dim,
+                value_dim,
+                device=flat.device,
+                dtype=flat.dtype,
+            )
+            t[: padded_lens[i]] = flat[start:end]
+            pieces.append(t)
+            start = end
+        return torch.stack(pieces, dim=0)
+
+    dq = _unflatten(dq_flat, q_padded_lens, max(q_padded_lens), nheads, d)
+    dk = _unflatten(dk_flat, k_padded_lens, max(k_padded_lens), k.size(2), d)
+    dv = _unflatten(dv_flat, k_padded_lens, max(k_padded_lens), k.size(2), d_v)
+
+    return out_stack, dq, dk, dv


 @pytest.mark.parametrize("input_layout", ["BSHD", "KVPACKED"])
@@ -612,7 +650,7 @@ def flash_attn_varlen_func_benchmark(
 @pytest.mark.parametrize("deterministic", [True, False])
 @pytest.mark.parametrize(
     "padding_scenario",
-    ["mixed", "q_only", "k_only", "no_padding", "q_len_1", "k_len_1"],
+    ["mixed", "q_only", "k_only", "no_padding"],
 )
 @pytest.mark.parametrize("dtype", [dtypes.fp16, dtypes.bf16])
 @pytest.mark.parametrize(
@@ -686,10 +724,6 @@ def test_varlen_flash_attn_seq_padding(
     elif padding_scenario == "no_padding":
         q_actual_lens = q_padded_lens
         k_actual_lens = k_padded_lens
-    elif padding_scenario == "q_len_1":
-        q_actual_lens = [1] * batch_size
-    elif padding_scenario == "k_len_1":
-        k_actual_lens = [1] * batch_size

     q_s = max(q_padded_lens)
     k_s = max(k_padded_lens)
@@ -710,6 +744,10 @@ def test_varlen_flash_attn_seq_padding(
             k_actual_lens[i], nheads_k, d_v, device=device, dtype=dtype
         )

+    q.requires_grad_(True)
+    k.requires_grad_(True)
+    v.requires_grad_(True)
+
     query_padding_mask = torch.arange(q_s, device=device).unsqueeze(0).expand(
         batch_size, -1
     ) < torch.tensor(q_actual_lens, device=device).unsqueeze(1)
@@ -717,7 +755,8 @@ def test_varlen_flash_attn_seq_padding(
         batch_size, -1
     ) < torch.tensor(k_actual_lens, device=device).unsqueeze(1)

-    out_ck = run_ck_seq_padding(
+    dout = torch.randn_like(q, dtype=q.dtype, device=device)
+    out_ck, dq_ck, dk_ck, dv_ck = run_ck_seq_padding(
         q,
         k,
         v,
         cu_seqlens_q,
         cu_seqlens_k,
         q_padded_lens,
         k_padded_lens,
         deterministic,
         causal=True,
         window_size=window_size,
+        dout=dout,
     )

-    out_ref = run_torch(
+    out_ref, dq_ref, dk_ref, dv_ref = run_torch(
         q,
         k,
         v,
         query_padding_mask,
         key_padding_mask,
         bias=None,
         alibi_slopes=None,
-        dout=None,
+        dout=dout,
         dropout_p=0.0,
         dropout_mask=None,
         causal=True,
         window_size=window_size,
     )

-    out_pt = run_torch(
+    out_pt, dq_pt, dk_pt, dv_pt = run_torch(
         q,
         k,
         v,
         query_padding_mask,
         key_padding_mask,
         bias=None,
         alibi_slopes=None,
-        dout=None,
+        dout=dout,
         dropout_p=0.0,
         dropout_mask=None,
         causal=True,
@@ -785,6 +825,74 @@ def test_varlen_flash_attn_seq_padding(
     )

     assert out_diff <= out_tol

+    def _mask_grad(tensor, lens):
+        masked = tensor.clone()
+        for i, length in enumerate(lens):
+            masked[i, length:] = 0
+        return masked
+
+    dq_ref_masked = _mask_grad(dq_ref, q_actual_lens)
+    dq_pt_masked = _mask_grad(dq_pt, q_actual_lens)
+    dq_ck_masked = _mask_grad(dq_ck, q_actual_lens)
+
+    dk_ref_masked = _mask_grad(dk_ref, k_actual_lens)
+    dk_pt_masked = _mask_grad(dk_pt, k_actual_lens)
+    dk_ck_masked = _mask_grad(dk_ck, k_actual_lens)
+
+    dv_ref_masked = _mask_grad(dv_ref, k_actual_lens)
+    dv_pt_masked = _mask_grad(dv_pt, k_actual_lens)
+    dv_ck_masked = _mask_grad(dv_ck, k_actual_lens)
+
+    dq_pt_diff = (dq_pt_masked - dq_ref_masked).abs().max().item()
+    dk_pt_diff = (dk_pt_masked - dk_ref_masked).abs().max().item()
+    dv_pt_diff = (dv_pt_masked - dv_ref_masked).abs().max().item()
+    print(f"dQ Pytorch max diff (masked): {dq_pt_diff}")
+    print(f"dK Pytorch max diff (masked): {dk_pt_diff}")
+    print(f"dV Pytorch max diff (masked): {dv_pt_diff}")
+
+    dq_tol = max(10 * dq_pt_diff, 0.01)
+    dk_tol = max(10 * dk_pt_diff, 0.01)
+    dv_tol = max(10 * dv_pt_diff, 0.01)
+
+    dq_ck_diff = (dq_ck_masked - dq_ref_masked).abs().max().item()
+    dk_ck_diff = (dk_ck_masked - dk_ref_masked).abs().max().item()
+    dv_ck_diff = (dv_ck_masked - dv_ref_masked).abs().max().item()
+
+    print(f"dQ CK max diff (masked): {dq_ck_diff}")
+    print(f"dK CK max diff (masked): {dk_ck_diff}")
+    print(f"dV CK max diff (masked): {dv_ck_diff}")
+
+    assert dq_ck_diff <= dq_tol
+    assert dk_ck_diff <= dk_tol
+    assert dv_ck_diff <= dv_tol
+
+
+@benchmark()
+def varlen_flash_attn_seq_padding_benchmark(
+    batch_size,
+    mha_type,
+    deterministic,
+    padding_scenario,
+    dtype,
+    d,
+    d_v,
+    seqlen_q,
+    seqlen_k,
+    local,
+):
+    return test_varlen_flash_attn_seq_padding(
+        batch_size=batch_size,
+        mha_type=mha_type,
+        deterministic=deterministic,
+        padding_scenario=padding_scenario,
+        dtype=dtype,
+        d=d,
+        d_v=d_v,
+        seqlen_q=seqlen_q,
+        seqlen_k=seqlen_k,
+        local=local,
+    )


 l_causal = [False, True]
 l_local = [False, True]
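Gradients in the padding rows are not meaningful, so the test zeroes them on both sides with _mask_grad before comparing, and accepts CK output within ten times the PyTorch implementation's own deviation from the fp32 reference, floored at 0.01. That tolerance rule, stated as a one-liner (my paraphrase of the code above):

    def tol(pt_diff: float) -> float:
        # dq_tol / dk_tol / dv_tol above all follow this shape.
        return max(10 * pt_diff, 0.01)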
@@ -959,11 +1067,29 @@ def test_varlen_flash_attn_seq_padding(
                 args.input_layout,
             )
             collected.append(ret)
-            test_varlen_flash_attn_seq_padding(
+
+        # Run seq_padding benchmark
+        padding_collected = []
+        for (
+            dtype,
+            (dim_qk, dim_v),
+            mha_type,
+            deterministic,
+            padding_scenario,
+            local,
+        ) in itertools.product(
+            args.dtype,
+            args.d_qk_v,
+            args.mha_type,
+            l_deterministic,
+            ["mixed", "q_only", "k_only", "no_padding"],
+            l_local,
+        ):
+            ret = varlen_flash_attn_seq_padding_benchmark(
                 args.batch_size,
                 mha_type,
                 deterministic,
-                "mixed",
+                padding_scenario,
                 dtypes.d_dtypes[dtype],
                 dim_qk,
                 dim_v,
                 seqlen_q,
                 seqlen_k,
                 local,
             )
+            padding_collected.append(ret)

     df = pd.DataFrame(collected)
     aiter.logger.info(f"mha_varlen summary:\n{df}")
+
+    df_padding = pd.DataFrame(padding_collected)
+    aiter.logger.info(f"mha_varlen_seq_padding summary:\n{df_padding}")