Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
0f14ca3
fmha: introduce FmhaSinkMode enum to unify compile-time sink control
LJ-underdog Mar 27, 2026
c3f4466
fmha: remove unused sink constants from pipeline/problem/kernel layers
LJ-underdog Mar 30, 2026
7b89aa3
fmha: rename has_sink to has_stream_sink in runtime traits
LJ-underdog Mar 31, 2026
d6b11a9
fmha: deduplicate SINK_*_MAP by importing from fmha_fwd
LJ-underdog Mar 31, 2026
bbef1af
fmha: remove dead kHasSink constant
LJ-underdog Mar 31, 2026
0b4433b
fmha: remove redundant tile_fmha_traits include from fmha_fwd_kernel
LJ-underdog Mar 31, 2026
3c7970b
fmha: guard pagedkv sink_value with if constexpr(kHasGptOssSink)
LJ-underdog Mar 31, 2026
395754e
fmha: guard splitkv sink_value with if constexpr(kHasGptOssSink)
LJ-underdog Mar 31, 2026
eb966f0
fmha: unify kSinkMode access path and reduce tile_fmha_traits includes
LJ-underdog Mar 31, 2026
eb25153
fmha: replace kSinkMode enum comparison with bool flags in kernel Get…
LJ-underdog Mar 31, 2026
ca10670
fmha: remove redundant tile_fmha_traits include from pipeline files
LJ-underdog Mar 31, 2026
66b7474
fmha: revert stray changes in block_fmha_pipeline_qr_ks_vs
LJ-underdog Mar 31, 2026
6375bf4
fmha: apply clang-format alignment fixes
LJ-underdog Mar 31, 2026
22be529
fmha: validate --sink argument in generate.py
LJ-underdog Mar 31, 2026
d3aba5b
Merge branch 'develop' into lj/reorg_sink
poyenc Apr 3, 2026
1d1f7a0
[CK Tile] Gate sink smoke tests behind opt-in flags in smoke_test_fwd.sh
LJ-underdog Apr 7, 2026
902698a
Merge branch 'develop' into lj/reorg_sink
LJ-underdog Apr 7, 2026
5f46a57
Merge branch 'develop' into lj/reorg_sink
LJ-underdog Apr 13, 2026
368851e
[CK Tile][FMHA] Address poyenc code review comments on PR #6057
LJ-underdog Apr 13, 2026
def6cab
[CK Tile][FMHA] Fix code alignment for sink-related fields
LJ-underdog Apr 14, 2026
74b19e9
[CK Tile][FMHA] Remove stray compiler workaround from sink mode refactor
LJ-underdog Apr 15, 2026
973e041
[CK Tile][FMHA] Fix F_sink="f" → "none" after FmhaSinkMode refactor
LJ-underdog Apr 15, 2026
c17001a
[CK] Skip fp16 dropout d256/d24 batch tests for compiler VGPR aliasin…
LJ-underdog Apr 16, 2026
841bfa0
[CK] Restrict fp16 dropout d256/d24 batch skip to gfx950
LJ-underdog Apr 16, 2026
8b19d58
Merge remote-tracking branch 'origin/develop' into lj/reorg_sink
LJ-underdog Apr 16, 2026
def192e
[CK] Fix clang-format: wrap long GTEST_SKIP string literal
LJ-underdog Apr 16, 2026
22d8f52
[CK] Remove gfx95 guard from fp16 dropout d256/d24 batch skip (ROCm 7…
LJ-underdog Apr 17, 2026
f6e785f
[CK] Add fp16 dropout d256 batch skip for ROCm 7.1.x to Dropout test
LJ-underdog Apr 17, 2026
77b26ad
[fmha] Skip fp16 d256 batch bias=e mask dropout on ROCm 7.1.x in smok…
LJ-underdog Apr 17, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,23 @@ set_directory_properties(PROPERTIES CMAKE_CONFIGURE_DEPENDS "${CODE_GEN_SCRIPTS}
list(JOIN INST_TARGETS , FMHA_TARGETS_ARG)

string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}")
set(FMHA_FWD_OPTDIM "32,64,80,128,256" CACHE STRING
"comma-separated list of hdim values to optimize for")
set(FMHA_FWD_FILTER "" CACHE STRING
"fnmatch filter pattern for fwd kernel instances. Empty = no filter.")
set(FMHA_FWD_SINK_MODES "none" CACHE STRING
"comma-separated list of sink modes to generate (none,stream,gptoss,both). Default: none.")

set(FMHA_FWD_CODE_GEN_COMMON_ARGS
${CMAKE_CURRENT_LIST_DIR}/generate.py
--targets ${FMHA_TARGETS_ARG}
--api ${FMHA_FWD_APIS}
--optdim 32,64,80,128,256
# --filter fmha_fwd...
--optdim ${FMHA_FWD_OPTDIM}
--sink ${FMHA_FWD_SINK_MODES}
)
if(FMHA_FWD_FILTER)
list(APPEND FMHA_FWD_CODE_GEN_COMMON_ARGS --filter ${FMHA_FWD_FILTER})
endif()
set(FMHA_BWD_CODE_GEN_COMMON_ARGS
${CMAKE_CURRENT_LIST_DIR}/generate.py
--targets ${FMHA_TARGETS_ARG}
Expand Down Expand Up @@ -98,6 +108,7 @@ add_custom_command(
--output_dir ${CMAKE_CURRENT_BINARY_DIR}
DEPENDS ${CODE_GEN_SCRIPTS}
COMMENT "Generate CK Tile FMHA FWD kernels"
VERBATIM
)

add_custom_command(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -828,7 +828,11 @@ def write_blobs(
receipt,
optdim_list,
mask_impl,
sink_modes=("none",),
) -> None:
# sink_modes is intentionally not forwarded: batch_prefill does not yet support
# StreamLLM/GPT-OSS sink kernels. The parameter exists only for API uniformity
# with other fwd handlers. Non-"none" modes are silently treated as "none".
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
for kernel in kernels:
write_single_fwd_kernel(kernel, output_dir)
Expand All @@ -842,7 +846,9 @@ def list_blobs(
receipt,
optdim_list,
mask_impl,
sink_modes=("none",),
) -> None:
# sink_modes is intentionally not forwarded: see write_blobs for rationale.
with file_path.open("a") as f:
_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
for kernel in kernels:
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,11 @@ def write_blobs(
receipt,
optdim_list,
mask_impl,
sink_modes=("none",),
) -> None:
# sink_modes is intentionally not forwarded: appendkv does not support sink kernels.
# The parameter exists only for API uniformity with other fwd handlers.
# Non-"none" modes are silently treated as "none".
api_pool, kernels = get_fwd_appendkv_blobs(
targets, kernel_filter, receipt, mask_impl, optdim_list
)
Expand All @@ -509,11 +513,16 @@ def list_blobs(
receipt,
optdim_list,
mask_impl,
sink_modes=("none",),
) -> None:
# sink_modes is intentionally not forwarded: see write_blobs for rationale.
with file_path.open("a") as f:
_, kernels = get_fwd_appendkv_blobs(
targets, kernel_filter, receipt, mask_impl, optdim_list
)
for kernel in kernels:
f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n")
f.write((file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME).as_posix() + "\n")
f.write(
(file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME).as_posix()
+ "\n"
)
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
FMHA_FWD_API_PER_ARCH,
FMHA_FWD_API_PER_DTYPE,
FMHA_FWD_API_PER_HDIM_CASE,
SINK_MODE_MAP,
SINK_MODE_DISPATCH_MAP,
SINK_NAME_MAP,
)


Expand Down Expand Up @@ -74,7 +77,7 @@
kHasUnevenSplits,
kMergeNumHeadGroupsSeqLenQ,
{F_occupancy},
{F_sink}>;
{F_sink_mode}>;

using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
Expand Down Expand Up @@ -118,7 +121,7 @@
}} // anonymous namespace

using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_sink}, {F_spad}, {F_skpad}, {F_dpad},
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_sink_mode}, {F_spad}, {F_skpad}, {F_dpad},
{F_dvpad}>;

#pragma clang diagnostic push
Expand Down Expand Up @@ -280,8 +283,8 @@
"""

FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
((a.block_table_ptr != nullptr) == {F_pagedkv}) && (t.has_sink == {F_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv},{F_sink}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
((a.block_table_ptr != nullptr) == {F_pagedkv}) && (t.has_stream_sink == {F_stream_sink}) && (t.has_gptoss_sink == {F_gptoss_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv},{F_sink_mode}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;

// get combine kernel tile sizes
using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType;
Expand Down Expand Up @@ -427,7 +430,7 @@ class FmhaFwdSplitKVPipeline:
F_lse: str #
F_squant: str #
F_pagedkv: str # t/f
F_sink: str # t/f
F_sink: str # "none" / "stream" / "gptoss" / "both"
F_mask: str # value from MASK_MAP

@property
Expand Down Expand Up @@ -488,10 +491,7 @@ def pad_name() -> str:
n += "_pagedkv"
else:
n += "_npagedkv"
if self.F_sink == "t":
n += "_sink"
else:
n += "_nsink"
n += SINK_NAME_MAP[self.F_sink]
return n


Expand Down Expand Up @@ -574,7 +574,9 @@ def api(self) -> str:
F_lse=BOOL_MAP[trait.lse],
F_squant=BOOL_MAP[trait.squant],
F_pagedkv=BOOL_MAP[trait.pagedkv],
F_sink=BOOL_MAP[trait.sink],
F_sink_mode=SINK_MODE_MAP[trait.sink],
F_stream_sink=SINK_MODE_DISPATCH_MAP[trait.sink][0],
F_gptoss_sink=SINK_MODE_DISPATCH_MAP[trait.sink][1],
F_scheck=trait.scheck,
F_skcheck=trait.skcheck,
F_dcheck=trait.dcheck,
Expand Down Expand Up @@ -675,7 +677,7 @@ def template(self) -> str:
F_squant=BOOL_MAP[self.F_pipeline.F_squant],
F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv],
F_occupancy=self.F_tile.F_occupancy,
F_sink=BOOL_MAP[self.F_pipeline.F_sink],
F_sink_mode=SINK_MODE_MAP[self.F_pipeline.F_sink],
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
F_mode=MODE_MAP[self.F_mode],
Expand Down Expand Up @@ -740,7 +742,9 @@ def filename(self) -> str:

class KernelComponentFactoryBase:
@staticmethod
def get_pipelines(dtype, hdim, mask_impl) -> List[FmhaFwdSplitKVPipeline]:
def get_pipelines(
dtype, hdim, mask_impl, sink_modes=("none",)
) -> List[FmhaFwdSplitKVPipeline]:
# this function populates a list of possible pipelines
# TODO: the order of the list matters! Entries later in this list will also be checked later
# TODO: currently for the qr pipeline, let "t" padding appear later!!
Expand All @@ -754,7 +758,7 @@ def get_pipelines(dtype, hdim, mask_impl) -> List[FmhaFwdSplitKVPipeline]:
get_mask_map(mask_impl).keys(),
BIAS_MAP.keys(),
["t", "f"],
["t", "f"],
sink_modes,
):
pipelines.append(Pipeline("qr", "row", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
Expand All @@ -764,8 +768,8 @@ def get_pipelines(dtype, hdim, mask_impl) -> List[FmhaFwdSplitKVPipeline]:
for logits, mask, bias in itertools.product(
["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
):
pipelines.append(Pipeline("qr", "row", "f", "f", "f", "f", logits, bias, "t", squant, "f", "f", mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, "f", "f", mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "f", "f", "f", "f", logits, bias, "t", squant, "f", "none", mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, "f", "none", mask)) # fmt: skip
elif dtype in ["fp8fp16", "fp8bf16"]:
# TODO
None
Expand Down Expand Up @@ -891,7 +895,12 @@ def get_factory(target: str):


def get_fwd_splitkv_blobs(
targets: List[str], kernel_filter: Optional[str], receipt, mask_impl, optdim_list
targets: List[str],
kernel_filter: Optional[str],
receipt,
mask_impl,
optdim_list,
sink_modes=("none",),
) -> List[FmhaFwdSplitKVKernel]:
Kernel = FmhaFwdSplitKVKernel

Expand All @@ -907,7 +916,7 @@ def get_fwd_splitkv_blobs(
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
tile = d[hdim_str]
hdim = int(hdim_str)
for pipeline in factory.get_pipelines(dtype, hdim, mask_impl):
for pipeline in factory.get_pipelines(dtype, hdim, mask_impl, sink_modes):
if mode == "group":
if pipeline.F_spad != "t" or pipeline.F_skpad != "t":
# in group mode, spad/skpad must be true, since we can't predict whether the seqlen of the current batch needs padding or not
Expand Down Expand Up @@ -942,7 +951,7 @@ def get_fwd_splitkv_blobs(
# FlashAttention splitkv paths use softcap-disabled kernels only.
cond &= pipeline.F_logits == "f"
cond &= pipeline.F_squant == "f"
cond &= pipeline.F_sink == "f"
cond &= pipeline.F_sink == "none"
if not cond:
continue
# PyTorch integration
Expand All @@ -952,7 +961,7 @@ def get_fwd_splitkv_blobs(
cond &= pipeline.F_bias in ["no", "bias"]
cond &= pipeline.F_squant == "f"
cond &= mode == "batch"
cond &= pipeline.F_sink == "f"
cond &= pipeline.F_sink == "none"
if not cond:
continue
# Aiter(mha_varlen_fwd) integration
Expand Down Expand Up @@ -1058,6 +1067,7 @@ def write_blobs(
receipt,
optdim_list,
mask_impl,
sink_modes=("none",),
) -> None:
filter_list = filter_list.split("@")
filter_list.extend([""] * (2 - len(filter_list)))
Expand All @@ -1068,7 +1078,7 @@ def write_blobs(
for kernel in combine_kernels:
write_single_kernel(kernel, output_dir)
kernels = get_fwd_splitkv_blobs(
targets, filter_list[1], receipt, mask_impl, optdim_list
targets, filter_list[1], receipt, mask_impl, optdim_list, sink_modes
)
for kernel in kernels:
write_single_kernel(kernel, output_dir)
Expand Down Expand Up @@ -1129,6 +1139,7 @@ def list_blobs(
receipt,
optdim_list,
mask_impl,
sink_modes=("none",),
) -> None:
filter_list = filter_list.split("@")
filter_list.extend([""] * (2 - len(filter_list)))
Expand All @@ -1140,7 +1151,7 @@ def list_blobs(
for kernel in kernels:
f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n")
kernels = get_fwd_splitkv_blobs(
targets, filter_list[1], receipt, mask_impl, optdim_list
targets, filter_list[1], receipt, mask_impl, optdim_list, sink_modes
)
for kernel in kernels:
f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n")
Expand Down
Loading
Loading