Merged

18 commits
4c23a9a
[fmha] Fix C++17 structured binding capture in batch_prefill kernel
LJ-underdog Apr 17, 2026
ae081a9
[fmha] Add StreamLLM sink support to batch_prefill pipeline
LJ-underdog Apr 15, 2026
18606d3
[fmha] Add batch_prefill test support to fmha_fwd_runner
LJ-underdog Apr 15, 2026
421f474
[fmha] Fix batch_prefill runner integration and validate correctness
LJ-underdog Apr 15, 2026
2e9ef7c
[fmha] Fix V window mismatch after sink→window transition in batch_pr…
LJ-underdog Apr 15, 2026
6dce6e7
[fmha] Broaden batch_prefill test filter to cover fp16/bf16, lse/nlse…
LJ-underdog Apr 15, 2026
a7bd8f5
[fmha] Remove redundant assignments in batch_prefill init_args branch
LJ-underdog Apr 17, 2026
c449990
[fmha] Fix gfx11 VECTORIZED_LAYOUT incompatibility in batch_prefill c…
LJ-underdog Apr 17, 2026
ce2df18
[fmha] Skip b64x128 tile for gfx11 in batch_prefill codegen
LJ-underdog Apr 17, 2026
68c2f5c
[fmha] Skip all batch_prefill kernels for gfx11 targets
LJ-underdog Apr 17, 2026
1e4d82f
Merge branch 'develop' into ck/lj/batch_prefill_sink
LJ-underdog Apr 17, 2026
7afd9dd
[fmha] Extend batch_prefill gfx skip to all non-gfx9 architectures
LJ-underdog Apr 17, 2026
c5b1c23
[fmha] Remove batch_prefill-specific codegen filter from CMakeLists
LJ-underdog Apr 18, 2026
f18e006
Merge branch 'develop' into ck/lj/batch_prefill_sink
LJ-underdog Apr 20, 2026
b8ee22a
Merge branch 'develop' into ck/lj/batch_prefill_sink
LJ-underdog Apr 20, 2026
618e4c3
Merge branch 'develop' into ck/lj/batch_prefill_sink
LJ-underdog Apr 21, 2026
5904d25
Merge branch 'develop' into ck/lj/batch_prefill_sink
LJ-underdog Apr 21, 2026
2d135a8
Merge branch 'develop' into ck/lj/batch_prefill_sink
LJ-underdog Apr 21, 2026
@@ -10,7 +10,7 @@ if(NOT INST_TARGETS)
endif()

# validate user-specified fmha_fwd API list
-set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv;pagedkv_prefill")
+set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv;pagedkv_prefill;batch_prefill")
set(FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING
"semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".")
if(BUILD_TESTING)
@@ -48,7 +48,6 @@ set(FMHA_FWD_CODE_GEN_COMMON_ARGS
--targets ${FMHA_TARGETS_ARG}
--api ${FMHA_FWD_APIS}
--optdim 32,64,80,128,256
-# --filter fmha_fwd...
)
set(FMHA_BWD_CODE_GEN_COMMON_ARGS
${CMAKE_CURRENT_LIST_DIR}/generate.py
@@ -174,6 +173,13 @@ else()
list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=0)
endif()

+# conditionally enable call to the batch_prefill API in fmha_fwd example and tests
+if("batch_prefill" IN_LIST FMHA_FWD_ENABLE_APIS)
+list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_BATCH_PREFILL_API=1)
+else()
+list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_BATCH_PREFILL_API=0)
+endif()

# conditionally specify the use of OCP_FP8
if(CK_USE_OCP_FP8)
list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
@@ -84,6 +84,7 @@
{F_qscale},
{F_occupancy},
false,
+{F_sink},
{F_page_size},
{F_kv_memory_layout},
{F_kv_lookup_table}>;
@@ -124,7 +125,7 @@
ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;

using trait_{F_idx} = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
-{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
+{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;

#include <iostream>

@@ -201,9 +202,9 @@
}}
"""

-FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) &&
+FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.has_sink == {F_sink}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size})) {{
-using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
+using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
return fmha_batch_prefill_<trait_>(s, a);
}}
"""
@@ -247,6 +248,7 @@ class FmhaFwdApiTrait:
skpad: str
dpad: str
dvpad: str
+sink: str # t/f
constraint: CppConstraint
kv_memory_layout: str
kv_lookup_table: str
@@ -343,6 +345,7 @@ class FmhaFwdPipeline:
F_dropout: str #
F_qscale: str # no/pertensor
F_mask: str # value from MASK_MAP
+F_sink: str # t/f (StreamLLM sink tokens)
F_kv_memory_layout: str #
F_kv_lookup_table: str #
F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint())
@@ -406,6 +409,11 @@ def pad_name() -> str:
else:
n += "_nqscale"

if self.F_sink == "t":
n += "_sink"
else:
n += "_nsink"

n += "_" + self.F_kv_memory_layout + "_" + self.F_kv_lookup_table
return n
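
For reference, the new markers join the other per-axis segments in the generated kernel names (commit 6dce6e7 refers to the matching lse/nlse pair). A toy reconstruction of the suffix scheme, with the lse and dropout spellings assumed rather than taken from this diff:

def feature_suffix(lse: str, dropout: str, sink: str) -> str:
    # Each boolean axis contributes a positive marker or an "n"-prefixed
    # negative one, mirroring the surrounding naming logic above.
    n = ""
    n += "_lse" if lse == "t" else "_nlse"
    n += "_dropout" if dropout == "t" else "_ndropout"
    n += "_sink" if sink == "t" else "_nsink"
    return n

print(feature_suffix("t", "f", "t"))  # -> _lse_ndropout_sink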

@@ -472,6 +480,7 @@ def api(self) -> str:
trait.kv_lookup_table
],
F_page_size=trait.page_size,
+F_sink=BOOL_MAP[trait.sink],
)
if_j = "if" if j == 0 else "else if"
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
@@ -578,6 +587,7 @@ def template(self) -> str:
F_mode=MODE_MAP[self.F_mode],
F_pipeline=FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag],
F_page_size=self.F_page_size,
+F_sink=BOOL_MAP[self.F_pipeline.F_sink],
)

@property
@@ -617,6 +627,7 @@ def api_trait(self) -> FmhaFwdApiTrait:
skpad=self.F_pipeline.F_skpad,
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad,
+sink=self.F_pipeline.F_sink,
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint,
kv_memory_layout=self.F_pipeline.F_kv_memory_layout,
kv_lookup_table=self.F_pipeline.F_kv_lookup_table,
@@ -655,6 +666,7 @@ def get_pipelines(dtype, hdim, receipt, mask_impl) -> List[FmhaFwdPipeline]:
bias,
lse,
dropout,
+sink,
kv_memory_layout,
kv_lookup_table,
) in itertools.product(
@@ -663,12 +675,13 @@ def get_pipelines(dtype, hdim, receipt, mask_impl) -> List[FmhaFwdPipeline]:
BIAS_MAP.keys(),
["t", "f"],
["t", "f"],
["t", "f"],
SUPPORTED_KV_MEMORY_LAYOUT,
SUPPORTED_KV_LOOKUP_TABLE,
):
-pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, kv_memory_layout, kv_lookup_table)) # fmt: skip
+pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, sink, kv_memory_layout, kv_lookup_table)) # fmt: skip
elif dtype in ["fp8bf16"]:
-# no need lse/dropout kernels
+# no need lse/dropout/sink kernels
for (
logits,
qscale,
@@ -684,7 +697,7 @@
SUPPORTED_KV_MEMORY_LAYOUT,
SUPPORTED_KV_LOOKUP_TABLE,
):
-pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, kv_memory_layout, kv_lookup_table)) # fmt: skip
+pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", kv_memory_layout, kv_lookup_table)) # fmt: skip
else:
assert False
return pipelines
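
Because the sink axis participates in the same itertools.product as the existing feature flags, it doubles the enumerated fp16/bf16 pipeline variants, while fp8bf16 stays pinned to "f". A self-contained illustration of that counting effect (axis names are illustrative, not the generator's):

import itertools

# One extra ["t", "f"] axis doubles the cartesian product, exactly as the
# sink entry added to get_pipelines() above does.
axes = {"lse": ["t", "f"], "dropout": ["t", "f"]}
print(len(list(itertools.product(*axes.values()))))  # 4
axes["sink"] = ["t", "f"]
print(len(list(itertools.product(*axes.values()))))  # 8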
@@ -701,20 +714,34 @@ def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:


def get_fwd_blobs(
-kernel_filter: Optional[str], receipt, optdim_list, mask_impl
+kernel_filter: Optional[str], receipt, optdim_list, mask_impl,
+targets: Optional[List[str]] = None
) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
+# batch_prefill pipeline uses gfx9-specific async scatter-gather buffer addressing
+# (amd_buffer_addressing.hpp raw buffer loads) that is not compatible with
+# non-gfx9 architectures (gfx11/gfx12/gfx10 are wave32 and use different
+# buffer instruction formats). Skip all batch_prefill kernels for non-gfx9 targets.
+has_non_gfx9 = targets is not None and any(
+not t.startswith("gfx9") for t in targets
+)
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future

gen = list()
api_pool = FmhaFwdApiPool(mask_impl)

+if has_non_gfx9:
+return api_pool, gen
+
for dtype in FWD_DTYPE_MAP.keys():
d = CustomFactory.get_hdim_tile_size_dict(dtype)
if d is None:
continue
# for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
for (hdim, tiles), mode in itertools.product(d.items(), MODE_MAP.keys()):
+# batch_prefill pipeline requires group mode (static_assert in pipeline problem)
+if mode != "group":
+continue
for tile, pipeline in itertools.product(
tiles, CustomFactory.get_pipelines(dtype, hdim, receipt, mask_impl)
):
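
The gating predicate is small enough to check in isolation; `_skip_batch_prefill` below is a hypothetical standalone helper mirroring the logic above, not a function added by this PR:

def _skip_batch_prefill(targets):
    # batch_prefill kernels are generated only when every requested target is
    # a gfx9-family architecture; a single non-gfx9 entry disables them all.
    return targets is not None and any(not t.startswith("gfx9") for t in targets)

assert _skip_batch_prefill(["gfx942", "gfx90a"]) is False    # gfx9-only: generate
assert _skip_batch_prefill(["gfx942", "gfx1100"]) is True    # mixed list: skip everything
assert _skip_batch_prefill(None) is False                    # no target list: generate

Skipping everything on a mixed list, rather than generating gfx9-only objects, makes sense here, presumably because each generated source is compiled for every requested target, so a single incompatible target would break the build.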
@@ -829,7 +856,7 @@ def write_blobs(
optdim_list,
mask_impl,
) -> None:
-api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
+api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl, targets)
for kernel in kernels:
write_single_fwd_kernel(kernel, output_dir)
write_fwd_api(api_pool, output_dir)
@@ -844,7 +871,7 @@ def list_blobs(
mask_impl,
) -> None:
with file_path.open("a") as f:
-_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
+_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl, targets)
for kernel in kernels:
f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n")
f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n")
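
Both entry points now forward `targets` into get_fwd_blobs, which matters because list_blobs names the files the build system expects while write_blobs produces the actual files; if only one of them applied the gfx9 gate, the two sets would disagree. A sketch of the wiring (the generate.py CLI is not shown in this diff, so the argparse flag shape is an assumption):

import argparse

parser = argparse.ArgumentParser()
# generate.py receives --targets from CMake (FMHA_TARGETS_ARG above); a
# comma-separated list is assumed for this sketch.
parser.add_argument("--targets", type=lambda s: s.split(","), default=None)
args = parser.parse_args(["--targets", "gfx942,gfx90a"])

# The same parsed list must reach both paths:
#   list_blobs(..., targets=args.targets)   # names the files CMake expects
#   write_blobs(..., targets=args.targets)  # generates those same files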
@@ -1452,6 +1452,7 @@ template <ck_tile::index_t HDim_,
bool kPadDv_,
bool kUseTrLoad_,
bool kSkipMinSeqlenQ_ = false,
+bool kHasSink_ = false,
ck_tile::index_t kPageBlockSize_ = 1,
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout_ =
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT,
@@ -1480,7 +1481,7 @@ struct fmha_fwd_batch_prefill_traits_ : public fmha_fwd_traits_<HDim_,
kPadDv_,
kUseTrLoad_,
kSkipMinSeqlenQ_,
-false>
+kHasSink_>
{
static constexpr auto kKVMemoryLayout = kKVMemoryLayout_;
static constexpr auto kKVLookupTable = kKVLookupTable_;