Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
3fa3733
add page size 16 to test and op
ltqin Sep 8, 2025
113217e
add num_total_pages to kernel parameter
ltqin Sep 9, 2025
6e2d9e4
add is_sglang parameter
ltqin Sep 9, 2025
7a463b7
change is_sglang to is_sglang_layout
ltqin Sep 9, 2025
ee72e04
kv last page size=16 pass
ltqin Sep 12, 2025
ae459b0
pass kv_last_page_lens to kernel
ltqin Sep 13, 2025
b25cee7
add parameters check before calling kernel
ltqin Sep 15, 2025
93754f4
change kv layout to [page_num, page_size, nhead, hdim]
ltqin Sep 17, 2025
8c52122
adopt the changes of struct fmha_fwd_batch_prefill_traits
Jeff-Huang Dec 13, 2025
9d7cd3f
change kv cache memory layout to [num_blocks, num_kv_heads, head_size…
Jeff-Huang Dec 19, 2025
e0cb1ea
[FMHA] Integrate vLLM block table support and enforce vectorized KV l…
Jeff-Huang Dec 24, 2025
ac28e9d
update CK
Jeff-Huang Dec 30, 2025
9d69a01
Merge branch 'main' into batch_prefill_page_size_16_rebase
Jeff-Huang Dec 30, 2025
688b141
update ck
Jeff-Huang Dec 30, 2025
0c9c886
adopt api changes from fmha_batch_prefill_traits
Jeff-Huang Dec 30, 2025
c75fee4
add support for linear kv cache layout
Jeff-Huang Dec 31, 2025
d144a76
update api
Jeff-Huang Dec 31, 2025
d727a92
Refactor the test code by gathering the different test functions into…
Jeff-Huang Dec 31, 2025
7642e79
Merge branch 'main' into batch_prefill_page_size_16_rebase
Jeff-Huang Dec 31, 2025
2917917
Merge branch 'main' into batch_prefill_page_size_16_rebase
Jeff-Huang Jan 5, 2026
b1f452c
update ck
Jeff-Huang Jan 5, 2026
ed5f66a
update ck
Jeff-Huang Jan 5, 2026
f5cc627
Add profile measurements for batch prefill function
Jeff-Huang Jan 6, 2026
c7dd47f
Merge branch 'main' into batch_prefill_page_size_16_rebase
Jeff-Huang Jan 7, 2026
9e10ffc
update ck
Jeff-Huang Jan 7, 2026
6a06de9
fix style
Jeff-Huang Jan 7, 2026
ae12e04
Merge branch 'main' into batch_prefill_page_size_16_rebase
Jeff-Huang Jan 7, 2026
db5f333
fix style
Jeff-Huang Jan 7, 2026
44a5cc7
Merge branch 'main' into batch_prefill_page_size_16_rebase
Jeff-Huang Jan 8, 2026
4de0de3
Merge branch 'main' into batch_prefill_page_size_16_rebase
Jeff-Huang Jan 9, 2026
1ed076f
[FMHA] Support 3D linear layout (page_size=1) and non-contiguous KV t…
Jeff-Huang Jan 10, 2026
ec79599
Merge branch 'main' into batch_prefill_page_size_16_rebase
Jeff-Huang Jan 12, 2026
ba88187
Merge branch 'main' into batch_prefill_page_size_16_rebase
Jeff-Huang Jan 13, 2026
e7af363
update ck
Jeff-Huang Jan 13, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/composable_kernel
Submodule composable_kernel updated 38 files
+141 −75 example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py
+90 −14 example/ck_tile/01_fmha/fmha_fwd.hpp
+3 −2 example/ck_tile/18_flatmm/CMakeLists.txt
+1 −1 example/ck_tile/18_flatmm/mixed_prec/a16w4_flatmm.hpp
+28 −3 example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp
+4 −3 example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.hpp
+12 −7 example/ck_tile/18_flatmm/mixed_prec/run_a16w4_moe_flatmm_example.inc
+3 −2 example/ck_tile/18_flatmm/mixed_prec/run_mixed_prec_flatmm.inc
+4 −2 example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cpp.in
+6 −4 example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc
+35 −15 experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp
+118 −0 experimental/builder/include/ck_tile/builder/factory/reference_common.hpp
+249 −0 experimental/builder/include/ck_tile/builder/factory/reference_factory.hpp
+191 −0 experimental/builder/include/ck_tile/builder/reflect/instance_traits_reference.hpp
+2 −1 experimental/builder/include/ck_tile/builder/types.hpp
+22 −13 experimental/builder/test/CMakeLists.txt
+9 −0 experimental/builder/test/impl/conv_algorithm_types.hpp
+1,031 −0 experimental/builder/test/validation/test_reference_execution.cpp
+117 −0 experimental/builder/test/validation/test_reference_instance_traits.cpp
+0 −1 include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp
+1 −3 include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
+17 −15 include/ck_tile/host/reference/reference_moe_gemm.hpp
+20 −17 include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp
+36 −20 include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp
+1 −0 include/ck_tile/ops/fmha.hpp
+32 −0 include/ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp
+174 −93 include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp
+376 −234 include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp
+64 −0 include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp
+42 −0 include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp
+16 −16 ...or_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_bilinear_instance.hpp
+12 −12 ...ensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_scale_instance.hpp
+6 −5 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_bilinear.hpp
+5 −4 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_scale.hpp
+3 −3 ..._weight_bilinear/wmma/device_grouped_conv3d_bwd_weight_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
+2 −2 ...3d_bwd_weight_scale/wmma/device_grouped_conv3d_bwd_weight_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
+1 −0 test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_bilinear.cpp
+1 −0 test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_scale.cpp
46 changes: 38 additions & 8 deletions aiter/ops/mha.py
Original file line number Diff line number Diff line change
Expand Up @@ -980,6 +980,9 @@ def cmdGenFunc_mha_batch_prefill(
k_descale: Optional[Tensor] = None,
v_descale: Optional[Tensor] = None,
gen: Optional[Generator] = None,
kv_last_page_lens: Optional[Tensor] = None,
block_table: Optional[Tensor] = None,
seqlen_k: Optional[Tensor] = None,
):
# causal=true is the same as causal=false in this case
causal = is_causal
Expand Down Expand Up @@ -2606,15 +2609,21 @@ def mha_batch_prefill_fake_tensors(
return_softmax_lse: bool,
return_dropout_randval: bool,
out: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None,
q_descale: Optional[torch.Tensor] = None,
k_descale: Optional[torch.Tensor] = None,
v_descale: Optional[torch.Tensor] = None,
gen: Optional[Generator] = None,
kv_last_page_lens: Optional[torch.Tensor] = None,
block_table: Optional[torch.Tensor] = None,
seqlen_k: Optional[torch.Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
# ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
if k.dim() != 5 or v.dim() != 5:
raise ValueError("Batch prefill requires 5D vectorized K/V tensors")
num_heads = q.size(1) # num_heads = q.sizes()[1]
head_size_v = v.size(2) # head_size_v = v.size(2)
head_size_v = v.size(-2) # head_size_v = v.size(-2)
total_q = q.size(0) # total_q = q.size(0)

if out is None:
Expand Down Expand Up @@ -2679,6 +2688,9 @@ def mha_batch_prefill(
q_descale: Optional[torch.Tensor] = None,
k_descale: Optional[torch.Tensor] = None,
v_descale: Optional[torch.Tensor] = None,
kv_last_page_lens: Optional[Tensor] = None,
block_table: Optional[Tensor] = None,
seqlen_k: Optional[Tensor] = None,
gen: Optional[Generator] = None,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ...

Expand All @@ -2704,6 +2716,9 @@ def _mha_batch_prefill(
return_softmax: bool = False,
zero_tensors: bool = False,
out: torch.Tensor = None,
kv_last_page_lens: torch.Tensor = None,
block_table: torch.Tensor = None,
seqlen_k: torch.Tensor = None,
q_descale: Optional[torch.Tensor] = None,
k_descale: Optional[torch.Tensor] = None,
v_descale: Optional[torch.Tensor] = None,
Expand Down Expand Up @@ -2734,6 +2749,9 @@ def _mha_batch_prefill(
q_descale,
k_descale,
v_descale,
kv_last_page_lens,
block_table,
seqlen_k,
# custom_build_args={"md_name": md_name, "blob_gen_cmd": blob_gen_cmd},
)
return out, softmax_lse, S_dmask, rng_state
Expand All @@ -2758,19 +2776,28 @@ def mha_batch_prefill_func(
return_lse=False,
return_attn_probs=False,
out=None,
kv_last_page_lens=None,
block_table=None,
seqlen_k=None,
q_descale=None,
k_descale=None,
v_descale=None,
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
head_size_q_og = q.size(2)
head_size_v_og = v.size(2)
if head_size_q_og % 8 != 0:
q = torch.nn.functional.pad(q, [0, 8 - head_size_q_og % 8])
k = torch.nn.functional.pad(k, [0, 8 - head_size_q_og % 8])
if head_size_v_og % 8 != 0:
v = torch.nn.functional.pad(v, [0, 8 - head_size_v_og % 8])
head_size_q_og = q.size(-1)
k_vector_size = 16 // k.element_size()
Comment thread
valarLip marked this conversation as resolved.
if k.dim() != 5 or v.dim() != 5:
raise ValueError("Batch prefill requires 5D vectorized K/V tensors")
head_size_v_og = v.size(-2)
if head_size_q_og % k_vector_size != 0 or head_size_v_og % k_vector_size != 0:
raise ValueError("Vectorized KV requires head size divisible by vector size")
if k.size(-3) * k_vector_size != head_size_q_og:
raise ValueError("K vectorized layout does not match Q head size")
if k.size(-2) % k_vector_size != 0:
raise ValueError("Vectorized KV requires page size divisible by vector size")
if not k.is_contiguous() or not v.is_contiguous():
raise ValueError("Vectorized KV requires contiguous K/V")
out_padded, softmax_lse, S_dmask, rng_state = _mha_batch_prefill(
q,
k,
Expand All @@ -2790,6 +2817,9 @@ def mha_batch_prefill_func(
return_lse=return_lse,
return_softmax=return_attn_probs and dropout_p > 0,
out=out,
kv_last_page_lens=kv_last_page_lens,
block_table=block_table,
seqlen_k=seqlen_k,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
Expand Down
37 changes: 35 additions & 2 deletions csrc/cpp_itfs/mha_fwd_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,35 @@
has_sink);
}}

mha_batch_prefill_traits get_mha_batch_prefill_traits(int head_size_q,
int head_size_v,
std::string dtype,
bool is_group_mode,
bool has_logits_soft_cap,
mask_enum mask_type,
bias_enum bias_type,
bool has_lse,
bool has_dropout,
quant_scale_enum qscale_type,
ck_tile::BlockAttentionKVCacheLookupTableEnum kv_lookup_table,
int page_size,
bool skip_min_seqlen_q = false)
{{
return mha_batch_prefill_traits(head_size_q,
head_size_v,
dtype,
is_group_mode,
has_logits_soft_cap,
mask_type,
bias_type,
has_lse,
has_dropout,
qscale_type,
skip_min_seqlen_q,
kv_lookup_table,
page_size);
}}

mha_fwd_splitkv_traits get_mha_fwd_splitkv_traits(int head_size_q,
int head_size_v,
std::string dtype,
Expand Down Expand Up @@ -161,7 +190,10 @@
int head_size_q = args.hdim_q;
int head_size_v = args.hdim_v;
bool has_dropout = args.p_drop > 0.f;
auto traits = get_mha_fwd_traits(head_size_q,
auto kv_lookup_table = args.block_table_ptr != nullptr
? ck_tile::BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D
: ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D;
auto traits = get_mha_batch_prefill_traits(head_size_q,
head_size_v,
q_dtype_str,
is_group_mode,
Expand All @@ -171,7 +203,8 @@
has_lse,
has_dropout,
qscale_type,
use_ext_asm);
kv_lookup_table,
args.page_block_size);
return fmha_batch_prefill(traits, args, stream_config);
}"""

Expand Down
19 changes: 9 additions & 10 deletions csrc/include/aiter_hip_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include <hip/hip_runtime.h>
#include <cstdint>
#include <hip/hip_runtime.h>
#include <iostream>

enum class GPUArch
Expand All @@ -12,15 +12,14 @@ enum class GPUArch
gfx950
};


#define CHECK_COND(x) \
do { \
if (!(x)) { \
std::cerr << "check failed, file=" \
<< __FILE__ << ", line=" \
<< __LINE__ << std::endl; \
std::terminate(); \
} \
#define CHECK_COND(x) \
do \
{ \
if(!(x)) \
{ \
std::cerr << "check failed, file=" << __FILE__ << ", line=" << __LINE__ << std::endl; \
std::terminate(); \
} \
} while(0)

#define HIP_CALL(call) \
Expand Down
33 changes: 18 additions & 15 deletions csrc/include/groupnorm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,29 @@

namespace rocm_torch_x {

class __attribute__ ((visibility("hidden"))) GroupNorm final
class __attribute__((visibility("hidden"))) GroupNorm final
{
public:
public:
explicit GroupNorm() = default;
~GroupNorm() = default;
public:
~GroupNorm() = default;

public:
// return empty if not supported
std::optional<torch::Tensor> Run(
torch::Tensor x,
int num_groups,
torch::Tensor weights,
torch::Tensor bias,
float epsilon);
private:
template<typename T>
torch::Tensor launchGroupNormKernel(uint32_t num_groups, float epsilon,
const torch::Tensor x, const torch::Tensor weights, const torch::Tensor bias, hipStream_t stream);
std::optional<torch::Tensor>
Run(torch::Tensor x, int num_groups, torch::Tensor weights, torch::Tensor bias, float epsilon);

private:
template <typename T>
torch::Tensor launchGroupNormKernel(uint32_t num_groups,
float epsilon,
const torch::Tensor x,
const torch::Tensor weights,
const torch::Tensor bias,
hipStream_t stream);

void reserveMeanAccumulator(uint32_t nums_to_reserve, torch::Device device);
private:

private:
torch::Tensor mean_accumulator_;
};

Expand Down
38 changes: 37 additions & 1 deletion csrc/include/mha_fwd.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,42 @@ struct mha_fwd_traits : public fmha_fwd_traits
int how_v3_bf16_cvt;
};

// Constructor-only wrapper around fmha_batch_prefill_traits.
//
// The constructor forwards its arguments positionally into the base's brace
// initializer and pins three values that callers cannot override here:
//   - is_v_rowmajor          = true
//   - has_sink               = false
//   - KV cache memory layout = VECTORIZED_LAYOUT
//
// NOTE(review): the initializer order below must match the member order of
// fmha_batch_prefill_traits, which is declared outside this file -- re-check
// this list whenever that struct changes.
struct mha_batch_prefill_traits : public fmha_batch_prefill_traits
{
// All parameters are passed through unchanged to the base initializer;
// kv_lookup_table and page_size are placed after the fixed memory layout.
mha_batch_prefill_traits(int head_size_q,
int head_size_v,
std::string dtype,
bool is_group_mode,
bool has_logits_soft_cap,
mask_enum mask_type,
bias_enum bias_type,
bool has_lse,
bool has_dropout,
quant_scale_enum qscale_type,
bool skip_min_seqlen_q,
ck_tile::BlockAttentionKVCacheLookupTableEnum kv_lookup_table,
int page_size)
: fmha_batch_prefill_traits{
head_size_q,
head_size_v,
dtype,
is_group_mode,
true, // is_v_rowmajor
has_logits_soft_cap,
mask_type,
bias_type,
has_lse,
has_dropout,
qscale_type,
skip_min_seqlen_q,
false, // has_sink
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT,
kv_lookup_table,
page_size}
{
}
};

struct mha_fwd_splitkv_traits : public fmha_fwd_splitkv_traits
{
mha_fwd_splitkv_traits(int head_size_q,
Expand Down Expand Up @@ -85,7 +121,7 @@ __attribute__((visibility("default"))) float mha_fwd(mha_fwd_args args,
bool has_lse,
quant_scale_enum qscale_type,
bool use_ext_asm,
bool has_sink = false,
bool has_sink = false,
int how_v3_bf16_cvt = 1,
const void* seqstart_q_padding_ptr = nullptr,
const void* seqstart_k_padding_ptr = nullptr,
Expand Down
57 changes: 30 additions & 27 deletions csrc/include/rocm_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1008,33 +1008,36 @@ namespace py = pybind11;
py::arg("cu_seqlens_q_padded") = std::nullopt, \
py::arg("cu_seqlens_k_padded") = std::nullopt);

#define MHA_BATCH_PREFILL_PYBIND \
m.def("mha_batch_prefill", \
&aiter::torch_itfs::mha_batch_prefill, \
py::arg("q"), \
py::arg("k"), \
py::arg("v"), \
py::arg("cu_seqlens_q"), \
py::arg("kv_indptr"), \
py::arg("kv_page_indices"), \
py::arg("max_seqlen_q"), \
py::arg("max_seqlen_k"), \
py::arg("dropout_p"), \
py::arg("softmax_scale"), \
py::arg("logits_soft_cap"), \
py::arg("zero_tensors"), \
py::arg("is_causal"), \
py::arg("window_size_left"), \
py::arg("window_size_right"), \
py::arg("return_softmax_lse"), \
py::arg("return_dropout_randval"), \
py::arg("out") = std::nullopt, \
py::arg("bias") = std::nullopt, \
py::arg("alibi_slopes") = std::nullopt, \
py::arg("q_descale") = std::nullopt, \
py::arg("k_descale") = std::nullopt, \
py::arg("v_descale") = std::nullopt, \
py::arg("gen") = std::nullopt);
// Registers the "mha_batch_prefill" binding on `m`, pointing at
// aiter::torch_itfs::mha_batch_prefill.
// Arguments from "q" through "return_dropout_randval" are declared without a
// default; every argument from "out" onward defaults to std::nullopt.
// `m` is not defined by this macro and must be in scope where it is expanded
// (presumably the pybind11 module object -- confirm at the expansion site).
// NOTE(review): the py::arg order presumably has to mirror the parameter order
// of the bound C++ function -- verify against its declaration.
#define MHA_BATCH_PREFILL_PYBIND \
m.def("mha_batch_prefill", \
&aiter::torch_itfs::mha_batch_prefill, \
py::arg("q"), \
py::arg("k"), \
py::arg("v"), \
py::arg("cu_seqlens_q"), \
py::arg("kv_indptr"), \
py::arg("kv_page_indices"), \
py::arg("max_seqlen_q"), \
py::arg("max_seqlen_k"), \
py::arg("dropout_p"), \
py::arg("softmax_scale"), \
py::arg("logits_soft_cap"), \
py::arg("zero_tensors"), \
py::arg("is_causal"), \
py::arg("window_size_left"), \
py::arg("window_size_right"), \
py::arg("return_softmax_lse"), \
py::arg("return_dropout_randval"), \
py::arg("out") = std::nullopt, \
py::arg("bias") = std::nullopt, \
py::arg("alibi_slopes") = std::nullopt, \
py::arg("q_descale") = std::nullopt, \
py::arg("k_descale") = std::nullopt, \
py::arg("v_descale") = std::nullopt, \
py::arg("kv_last_page_lens") = std::nullopt, \
py::arg("block_table") = std::nullopt, \
py::arg("seqlen_k") = std::nullopt, \
py::arg("gen") = std::nullopt);

#define MOE_OP_PYBIND \
m.def("topk_softmax", \
Expand Down
8 changes: 6 additions & 2 deletions csrc/include/torch/mha_batch_prefill.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ namespace aiter {
namespace torch_itfs {
std::vector<at::Tensor>
mha_batch_prefill(at::Tensor& q, // [total_q, hq, d]
const at::Tensor& k, // [total_k, hk, d]
const at::Tensor& v, // [total_k, hk, d]
const at::Tensor& k, // [num_blocks, hk, d/8, block_size, 8]
const at::Tensor& v, // [num_blocks, hk, block_size/8, d, 8]
const at::Tensor& cu_seqlens_q, // [b+1]
const at::Tensor& kv_indptr, // [b+1]
const at::Tensor& kv_page_indices,
Expand All @@ -29,6 +29,10 @@ mha_batch_prefill(at::Tensor& q, // [total_q, hq, d]
std::optional<const at::Tensor> q_descale, // [1]
std::optional<const at::Tensor> k_descale, // [1]
std::optional<const at::Tensor> v_descale, // [1]
std::optional<const at::Tensor> kv_last_page_lens,
std::optional<const at::Tensor> block_table,
std::optional<const at::Tensor> seqlen_k,
std::optional<at::Generator> gen_);

} // namespace torch_itfs
} // namespace aiter
Loading
Loading