From 0f14ca38d7ba83e751d11eb8f0bf8bb7d693335f Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Fri, 27 Mar 2026 05:25:16 +0000 Subject: [PATCH 01/25] fmha: introduce FmhaSinkMode enum to unify compile-time sink control Replaces the single bool kHasSink_ with a FmhaSinkMode enum (kNone/kStreamLLM/kGptOss/kBoth) across all FMHA forward pipelines, kernels, and codegen. Key changes: - Add FmhaSinkMode enum to tile_fmha_traits.hpp; derive kHasSink, kHasStreamSink, kHasGptOssSink constants in all traits structs - Pipeline files: replace runtime __builtin_isinf_sign(sink_v) checks with if constexpr(kHasGptOssSink); replace kHasSink with kHasStreamSink for StreamLLM-only paths (tile range, mask, dropout seq offset, KV window jump) - Kernel files: sink_value computation guarded by kHasGptOssSink; kernel naming extended to _nsink/_ssink/_gsink/_bsink - fmha_fwd.hpp: fmha_fwd_traits_/pagedkv/splitkv use FmhaSinkMode; runtime traits structs add has_gptoss_sink field - generate.py: add --sink argument (default: none); codegen py files add sink_modes parameter to get_pipelines/write_blobs/list_blobs - CMakeLists.txt: add FMHA_FWD_SINK_MODES cache variable (default: none) to control which sink variants are compiled This eliminates runtime overhead for no-sink kernels (kHasGptOssSink= false compiles out all GPT-OSS paths), and fixes the VGPR register allocation bug in dropout kernels caused by sink_v*scale_s in the early-exit LSE path. Correctness validated with CPU reference for all GPT-OSS sink combinations: no-mask, causal mask, alibi, LSE, early-exit, group mode. --- .../example/ck_tile/01_fmha/CMakeLists.txt | 15 ++- .../01_fmha/codegen/ops/fmha_batch_prefill.py | 2 + .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 108 ++++++++++++------ .../01_fmha/codegen/ops/fmha_fwd_appendkv.py | 7 +- .../01_fmha/codegen/ops/fmha_fwd_splitkv.py | 76 ++++++++---- .../codegen/ops/fmha_pagedkv_prefill.py | 75 ++++++++---- .../example/ck_tile/01_fmha/fmha_fwd.hpp | 42 +++++-- .../ck_tile/01_fmha/fmha_fwd_runner.hpp | 3 +- .../example/ck_tile/01_fmha/generate.py | 40 ++++++- .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 26 +++-- .../fmha/kernel/fmha_fwd_pagedkv_kernel.hpp | 7 +- .../fmha/kernel/fmha_fwd_splitkv_kernel.hpp | 7 +- ..._batch_prefill_pipeline_qr_ks_vs_async.hpp | 18 ++- ...ock_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp | 17 ++- ...litkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp | 38 ++++-- ...ock_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp | 63 ++++++---- .../pipeline/block_fmha_pipeline_problem.hpp | 10 ++ .../pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 72 +++++++----- .../block_fmha_pipeline_qr_ks_vs_async.hpp | 19 +-- ...ck_fmha_pipeline_qr_ks_vs_async_trload.hpp | 14 ++- .../ops/fmha/pipeline/tile_fmha_traits.hpp | 51 +++++++-- 21 files changed, 511 insertions(+), 199 deletions(-) diff --git a/projects/composablekernel/example/ck_tile/01_fmha/CMakeLists.txt b/projects/composablekernel/example/ck_tile/01_fmha/CMakeLists.txt index 35afb1181e0..8153982f57d 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/CMakeLists.txt +++ b/projects/composablekernel/example/ck_tile/01_fmha/CMakeLists.txt @@ -43,13 +43,23 @@ set_directory_properties(PROPERTIES CMAKE_CONFIGURE_DEPENDS "${CODE_GEN_SCRIPTS} list(JOIN INST_TARGETS , FMHA_TARGETS_ARG) string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}") +set(FMHA_FWD_OPTDIM "32,64,80,128,256" CACHE STRING + "comma-separated list of hdim values to optimize for") +set(FMHA_FWD_FILTER "" CACHE STRING + "fnmatch filter pattern for fwd kernel instances. Empty = no filter.") +set(FMHA_FWD_SINK_MODES "none" CACHE STRING + "comma-separated list of sink modes to generate (none,stream,gptoss,both). Default: none.") + set(FMHA_FWD_CODE_GEN_COMMON_ARGS ${CMAKE_CURRENT_LIST_DIR}/generate.py --targets ${FMHA_TARGETS_ARG} --api ${FMHA_FWD_APIS} - --optdim 32,64,80,128,256 - # --filter fmha_fwd... + --optdim ${FMHA_FWD_OPTDIM} + --sink ${FMHA_FWD_SINK_MODES} ) +if(FMHA_FWD_FILTER) + list(APPEND FMHA_FWD_CODE_GEN_COMMON_ARGS --filter ${FMHA_FWD_FILTER}) +endif() set(FMHA_BWD_CODE_GEN_COMMON_ARGS ${CMAKE_CURRENT_LIST_DIR}/generate.py --targets ${FMHA_TARGETS_ARG} @@ -98,6 +108,7 @@ add_custom_command( --output_dir ${CMAKE_CURRENT_BINARY_DIR} DEPENDS ${CODE_GEN_SCRIPTS} COMMENT "Generate CK Tile FMHA FWD kernels" + VERBATIM ) add_custom_command( diff --git a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index 35e8c1be49d..1eb22bb9871 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -828,6 +828,7 @@ def write_blobs( receipt, optdim_list, mask_impl, + sink_modes=("none",), ) -> None: api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) for kernel in kernels: @@ -842,6 +843,7 @@ def list_blobs( receipt, optdim_list, mask_impl, + sink_modes=("none",), ) -> None: with file_path.open("a") as f: _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) diff --git a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 18490681613..c194a07f756 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -42,6 +42,30 @@ "mxfp4": 4, } +# Maps FmhaSinkMode string values to C++ FmhaSinkMode enum +SINK_MODE_MAP = { + "none": "ck_tile::FmhaSinkMode::kNone", + "stream": "ck_tile::FmhaSinkMode::kStreamLLM", + "gptoss": "ck_tile::FmhaSinkMode::kGptOss", + "both": "ck_tile::FmhaSinkMode::kBoth", +} + +# For backward compat dispatch check: map sink mode to (has_sink, has_gptoss_sink) +SINK_MODE_DISPATCH_MAP = { + "none": ("false", "false"), + "stream": ("true", "false"), + "gptoss": ("false", "true"), + "both": ("true", "true"), +} + +# Legacy sink name suffix for kernel files +SINK_NAME_MAP = { + "none": "_nsink", + "stream": "_ssink", + "gptoss": "_gsink", + "both": "_bsink", +} + K0_MAX_SUBMAX_MAP = { 32: 32, 48: 48, @@ -88,7 +112,7 @@ {F_qscale}, {F_occupancy}, {F_skip}, - {F_sink}>; + {F_sink_mode}>; using fmha_variant = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; @@ -125,7 +149,7 @@ using trait = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}, {F_sink}>; + {F_pipeline_enum}, {F_logits}, fmha_mask, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}, {F_sink_mode}>; template<> float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) @@ -240,9 +264,9 @@ }} """ -FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.skip_min_seqlen_q == {F_skip}) &&(t.has_sink == {F_sink}) && +FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.skip_min_seqlen_q == {F_skip}) && (t.has_sink == {F_stream_sink}) && (t.has_gptoss_sink == {F_gptoss_sink}) && ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ - using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}, {F_sink}>; + using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}, {F_sink_mode}>; return fmha_fwd_(s, a); }} """ @@ -396,7 +420,7 @@ class FmhaFwdPipeline: F_mask: str # value from MASK_MAP F_skip: str # true/false F_trload: str # true/false - F_sink: str # true/false + F_sink: str # "none" / "stream" / "gptoss" / "both" F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint()) @property @@ -467,10 +491,7 @@ def pad_name() -> str: n += "_trload" else: n += "_ntrload" - if self.F_sink == "t": - n += "_sink" - else: - n += "_nsink" + n += SINK_NAME_MAP[self.F_sink] return n @@ -560,7 +581,9 @@ def has_traits(node) -> bool: F_trload=BOOL_MAP[trait.tr_load], F_qscale_check=QSCALE_CHECK_MAP[trait.qscale], F_qscale=QSCALE_MAP[trait.qscale], - F_sink=BOOL_MAP[trait.sink], + F_sink_mode=SINK_MODE_MAP[trait.sink], + F_stream_sink=SINK_MODE_DISPATCH_MAP[trait.sink][0], + F_gptoss_sink=SINK_MODE_DISPATCH_MAP[trait.sink][1], F_scheck=trait.scheck, F_seqtune=trait.seqtune(max_bm0), F_skcheck=trait.skcheck, @@ -701,7 +724,7 @@ def render(self) -> str: F_pipeline=PIPELINE_MAP[self.F_pipeline.tag], F_kernel=self._get_cpp_kernel_class_name(self.F_pipeline.tag), F_kargs_creator=self._get_cpp_kargs_creator_func_name(self.F_pipeline.tag), - F_sink=BOOL_MAP[self.F_pipeline.F_sink], + F_sink_mode=SINK_MODE_MAP[self.F_pipeline.F_sink], ) @property @@ -974,7 +997,7 @@ def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: # support this in future @classmethod def get_pipelines( - cls, dtype, hdim, hdim_v, receipt, mask_impl + cls, dtype, hdim, hdim_v, receipt, mask_impl, sink_modes=("none",) ) -> List[FmhaFwdPipeline]: # this function will populate a list possible pipelines # TODO: the order of List matters! the later in this list will be also be checked later @@ -990,7 +1013,7 @@ def get_pipelines( ["t", "f"], ["t", "f"], ["t", "f"], - ["t", "f"], + sink_modes, ): pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "f", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip @@ -1004,7 +1027,7 @@ def get_pipelines( ["t", "f"], ["t", "f"], ["t", "f"], - ["t", "f"], + sink_modes, ): if hdim == 256 and hdim_v == 256: pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip @@ -1016,6 +1039,14 @@ def get_pipelines( # TODO: rocm 6.2 compiler problem if using qr_async for bias case pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip + elif mask not in ("s_no", "no") and ( + (bias == "no" and dropout == "t") + or (bias == "alibi" and dropout == "f") + ): + # TODO: compiler problem with qr_async for IsMasking=true + + # (no_bias+dropout) or (alibi+ndropout) combinations + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip else: pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip @@ -1028,7 +1059,7 @@ def get_pipelines( ["no", "pertensor", "blockscale"], get_mask_map(mask_impl).keys(), ["no"], - ["f", "t"], + sink_modes, ): if hdim == 64: pipelines.append(FmhaFwdPipeline("qr", "row", "t", "f", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip @@ -1079,10 +1110,10 @@ def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: @classmethod def get_pipelines( - cls, dtype, hdim, hdim_v, receipt, mask_impl + cls, dtype, hdim, hdim_v, receipt, mask_impl, sink_modes=("none",) ) -> List[FmhaFwdPipeline]: pipelines = KernelComponentFactoryGfx9.get_pipelines( - dtype, hdim, hdim_v, receipt, mask_impl + dtype, hdim, hdim_v, receipt, mask_impl, sink_modes ) if dtype in cls._DT_FP16_BF16: qscale = "no" @@ -1093,7 +1124,7 @@ def get_pipelines( ["t", "f"], ["t", "f"], ["t", "f"], - ["t", "f"], + sink_modes, ): if ( (hdim, hdim_v) in [(64, 64), (128, 128)] @@ -1110,7 +1141,7 @@ def get_pipelines( # qr_async_trload_v3 only supports (generic) causal mask for logits, mask in itertools.product(["t", "f"], ["no", "causal"]): pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f", - F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip + F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="none")) # fmt: skip elif dtype in cls._DT_MXFP8 or dtype in cls._DT_MXFP4: # no need dropout kernels @@ -1121,7 +1152,7 @@ def get_pipelines( ["mx"], get_mask_map(mask_impl).keys(), ["no"], - ["f", "t"], + sink_modes, ): pipelines.append(FmhaFwdPipeline("qr", "col", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, "f", "f", sink)) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "col", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, "f", "f", sink)) # fmt: skip @@ -1158,7 +1189,7 @@ def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: @classmethod def get_pipelines( - cls, dtype, hdim, hdim_v, receipt, mask_impl + cls, dtype, hdim, hdim_v, receipt, mask_impl, sink_modes=("none",) ) -> List[FmhaFwdPipeline]: pipelines = [] if dtype in cls._DT_FP16_BF16: @@ -1170,7 +1201,7 @@ def get_pipelines( ["t", "f"], ["t", "f"], ["t", "f"], - ["t", "f"], + sink_modes, ): # Keep only ttff/tttt for gfx11: ffff path is often similar or worse # pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip @@ -1230,7 +1261,7 @@ def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: @classmethod def get_pipelines( - cls, dtype, hdim, hdim_v, receipt, mask_impl + cls, dtype, hdim, hdim_v, receipt, mask_impl, sink_modes=("none",) ) -> List[FmhaFwdPipeline]: pipelines = [] if dtype in cls._DT_FP16_BF16: @@ -1242,22 +1273,23 @@ def get_pipelines( ["t", "f"], ["t", "f"], ["t", "f"], - ["t", "f"], + sink_modes, ): pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip elif dtype in cls._DT_FP8_FP8BF16 or dtype in cls._DT_FP8FP32: # no need lse/dropout kernels - for logits, qscale, mask, bias in itertools.product( + for logits, qscale, mask, bias, sink in itertools.product( ["f"], ["no", "pertensor", "blockscale"], get_mask_map(mask_impl).keys(), ["no"], + sink_modes, ): - pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip return pipelines @@ -1311,7 +1343,7 @@ def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: cond &= kernel_ctx.pipeline.F_bias in ["no", "alibi"] cond &= kernel_ctx.pipeline.F_qscale == "no" cond &= kernel_ctx.pipeline.F_skip == "f" - cond &= kernel_ctx.pipeline.F_sink == "f" + cond &= kernel_ctx.pipeline.F_sink == "none" return cond return Product(name="Flash attention integration", rule=fit) @@ -1408,7 +1440,12 @@ def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: def get_fwd_blobs( - targets: List[str], kernel_filter: Optional[str], receipt, optdim_list, mask_impl + targets: List[str], + kernel_filter: Optional[str], + receipt, + optdim_list, + mask_impl, + sink_modes=("none",), ) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: gen = list() api_pool = FmhaFwdApiPool() @@ -1430,7 +1467,10 @@ def get_fwd_blobs( ) for tile, pipeline in itertools.product( - tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) + tiles, + factory.get_pipelines( + dtype, hdim, hdim_v, receipt, mask_impl, sink_modes + ), ): problem_ctx = ProblemContext( dtype=dtype, mode=mode, hdim=hdim, hdim_v=hdim_v @@ -1493,9 +1533,10 @@ def write_blobs( receipt, optdim_list, mask_impl, + sink_modes=("none",), ) -> None: api_pool, kernels = get_fwd_blobs( - targets, kernel_filter, receipt, optdim_list, mask_impl + targets, kernel_filter, receipt, optdim_list, mask_impl, sink_modes ) for kernel in kernels: write_single_fwd_kernel(kernel, output_dir) @@ -1509,10 +1550,11 @@ def list_blobs( receipt, optdim_list, mask_impl, + sink_modes=("none",), ) -> None: with file_path.open("a") as f: _, kernels = get_fwd_blobs( - targets, kernel_filter, receipt, optdim_list, mask_impl + targets, kernel_filter, receipt, optdim_list, mask_impl, sink_modes ) for kernel in kernels: f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n") diff --git a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py index 793a743df74..7f1851380cb 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +++ b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -493,6 +493,7 @@ def write_blobs( receipt, optdim_list, mask_impl, + sink_modes=("none",), ) -> None: api_pool, kernels = get_fwd_appendkv_blobs( targets, kernel_filter, receipt, mask_impl, optdim_list @@ -509,6 +510,7 @@ def list_blobs( receipt, optdim_list, mask_impl, + sink_modes=("none",), ) -> None: with file_path.open("a") as f: _, kernels = get_fwd_appendkv_blobs( @@ -516,4 +518,7 @@ def list_blobs( ) for kernel in kernels: f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n") - f.write((file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME).as_posix() + "\n") + f.write( + (file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME).as_posix() + + "\n" + ) diff --git a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index e0ccde8a6b7..4de16a706a7 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -40,6 +40,27 @@ "qr_nwarp_sshuffle": "ck_tile::BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS", } +SINK_MODE_MAP = { + "none": "ck_tile::FmhaSinkMode::kNone", + "stream": "ck_tile::FmhaSinkMode::kStreamLLM", + "gptoss": "ck_tile::FmhaSinkMode::kGptOss", + "both": "ck_tile::FmhaSinkMode::kBoth", +} + +SINK_MODE_DISPATCH_MAP = { + "none": ("false", "false"), + "stream": ("true", "false"), + "gptoss": ("false", "true"), + "both": ("true", "true"), +} + +SINK_NAME_MAP = { + "none": "_nsink", + "stream": "_ssink", + "gptoss": "_gsink", + "both": "_bsink", +} + FMHA_FWD_SPLITKV_KERNEL_BODY = """ #include @@ -74,7 +95,7 @@ kHasUnevenSplits, kMergeNumHeadGroupsSeqLenQ, {F_occupancy}, - {F_sink}>; + {F_sink_mode}>; using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem< typename FmhaFwdTypeConfig::QDataType, @@ -118,7 +139,7 @@ }} // anonymous namespace using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_sink}, {F_spad}, {F_skpad}, {F_dpad}, + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_sink_mode}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; #pragma clang diagnostic push @@ -280,8 +301,8 @@ """ FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) && - ((a.block_table_ptr != nullptr) == {F_pagedkv}) && (t.has_sink == {F_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv},{F_sink}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + ((a.block_table_ptr != nullptr) == {F_pagedkv}) && (t.has_sink == {F_stream_sink}) && (t.has_gptoss_sink == {F_gptoss_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ + using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv},{F_sink_mode}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; // get combine kernel tile sizes using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType; @@ -427,7 +448,7 @@ class FmhaFwdSplitKVPipeline: F_lse: str # F_squant: str # F_pagedkv: str # t/f - F_sink: str # t/f + F_sink: str # "none" / "stream" / "gptoss" / "both" F_mask: str # value from MASK_MAP @property @@ -488,10 +509,7 @@ def pad_name() -> str: n += "_pagedkv" else: n += "_npagedkv" - if self.F_sink == "t": - n += "_sink" - else: - n += "_nsink" + n += SINK_NAME_MAP[self.F_sink] return n @@ -574,7 +592,9 @@ def api(self) -> str: F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv], - F_sink=BOOL_MAP[trait.sink], + F_sink_mode=SINK_MODE_MAP[trait.sink], + F_stream_sink=SINK_MODE_DISPATCH_MAP[trait.sink][0], + F_gptoss_sink=SINK_MODE_DISPATCH_MAP[trait.sink][1], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, @@ -675,7 +695,7 @@ def template(self) -> str: F_squant=BOOL_MAP[self.F_pipeline.F_squant], F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv], F_occupancy=self.F_tile.F_occupancy, - F_sink=BOOL_MAP[self.F_pipeline.F_sink], + F_sink_mode=SINK_MODE_MAP[self.F_pipeline.F_sink], F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], F_mode=MODE_MAP[self.F_mode], @@ -740,7 +760,9 @@ def filename(self) -> str: class KernelComponentFactoryBase: @staticmethod - def get_pipelines(dtype, hdim, mask_impl) -> List[FmhaFwdSplitKVPipeline]: + def get_pipelines( + dtype, hdim, mask_impl, sink_modes=("none",) + ) -> List[FmhaFwdSplitKVPipeline]: # this function will populate a list possible pipelines # TODO: the order of List matters! the later in this list will be also be checked later # TODO: currently for qr pipeline, let "t" padding to appear later!! @@ -754,7 +776,7 @@ def get_pipelines(dtype, hdim, mask_impl) -> List[FmhaFwdSplitKVPipeline]: get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], - ["t", "f"], + sink_modes, ): pipelines.append(Pipeline("qr", "row", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip pipelines.append(Pipeline("qr", "row", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip @@ -764,8 +786,8 @@ def get_pipelines(dtype, hdim, mask_impl) -> List[FmhaFwdSplitKVPipeline]: for logits, mask, bias in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys() ): - pipelines.append(Pipeline("qr", "row", "f", "f", "f", "f", logits, bias, "t", squant, "f", "f", mask)) # fmt: skip - pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, "f", "f", mask)) # fmt: skip + pipelines.append(Pipeline("qr", "row", "f", "f", "f", "f", logits, bias, "t", squant, "f", "none", mask)) # fmt: skip + pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, "f", "none", mask)) # fmt: skip elif dtype in ["fp8fp16", "fp8bf16"]: # TODO None @@ -891,7 +913,12 @@ def get_factory(target: str): def get_fwd_splitkv_blobs( - targets: List[str], kernel_filter: Optional[str], receipt, mask_impl, optdim_list + targets: List[str], + kernel_filter: Optional[str], + receipt, + mask_impl, + optdim_list, + sink_modes=("none",), ) -> List[FmhaFwdSplitKVKernel]: Kernel = FmhaFwdSplitKVKernel @@ -907,7 +934,7 @@ def get_fwd_splitkv_blobs( for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): tile = d[hdim_str] hdim = int(hdim_str) - for pipeline in factory.get_pipelines(dtype, hdim, mask_impl): + for pipeline in factory.get_pipelines(dtype, hdim, mask_impl, sink_modes): if mode == "group": if pipeline.F_spad != "t" or pipeline.F_skpad != "t": # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not @@ -940,7 +967,7 @@ def get_fwd_splitkv_blobs( cond &= pipeline.F_vlayout == "row" cond &= pipeline.F_bias in ["no", "alibi"] cond &= pipeline.F_squant == "f" - cond &= pipeline.F_sink == "f" + cond &= pipeline.F_sink == "none" if not cond: continue # PyTorch integration @@ -950,7 +977,7 @@ def get_fwd_splitkv_blobs( cond &= pipeline.F_bias in ["no", "bias"] cond &= pipeline.F_squant == "f" cond &= mode == "batch" - cond &= pipeline.F_sink == "f" + cond &= pipeline.F_sink == "none" if not cond: continue # Aiter(mha_varlen_fwd) integration @@ -1056,6 +1083,7 @@ def write_blobs( receipt, optdim_list, mask_impl, + sink_modes=("none",), ) -> None: filter_list = filter_list.split("@") filter_list.extend([""] * (2 - len(filter_list))) @@ -1066,7 +1094,7 @@ def write_blobs( for kernel in combine_kernels: write_single_kernel(kernel, output_dir) kernels = get_fwd_splitkv_blobs( - targets, filter_list[1], receipt, mask_impl, optdim_list + targets, filter_list[1], receipt, mask_impl, optdim_list, sink_modes ) for kernel in kernels: write_single_kernel(kernel, output_dir) @@ -1127,6 +1155,7 @@ def list_blobs( receipt, optdim_list, mask_impl, + sink_modes=("none",), ) -> None: filter_list = filter_list.split("@") filter_list.extend([""] * (2 - len(filter_list))) @@ -1138,8 +1167,11 @@ def list_blobs( for kernel in kernels: f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n") kernels = get_fwd_splitkv_blobs( - targets, filter_list[1], receipt, mask_impl, optdim_list + targets, filter_list[1], receipt, mask_impl, optdim_list, sink_modes ) for kernel in kernels: f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n") - f.write((file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME).as_posix() + "\n") + f.write( + (file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME).as_posix() + + "\n" + ) diff --git a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py index 1ac1f1c38a7..787c00ff778 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py +++ b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py @@ -34,6 +34,27 @@ ) +SINK_MODE_MAP = { + "none": "ck_tile::FmhaSinkMode::kNone", + "stream": "ck_tile::FmhaSinkMode::kStreamLLM", + "gptoss": "ck_tile::FmhaSinkMode::kGptOss", + "both": "ck_tile::FmhaSinkMode::kBoth", +} + +SINK_MODE_DISPATCH_MAP = { + "none": ("false", "false"), + "stream": ("true", "false"), + "gptoss": ("false", "true"), + "both": ("true", "true"), +} + +SINK_NAME_MAP = { + "none": "_nsink", + "stream": "_ssink", + "gptoss": "_gsink", + "both": "_bsink", +} + FMHA_FWD_PAGEDKV_PIPELINE_MAP = { "qr_pagedkv": "ck_tile::BlockFmhaFwdPagedKVPipelineQRKSVS" } @@ -66,7 +87,7 @@ {F_squant}, {F_occupancy}, {F_skip}, - {F_sink}>; + {F_sink_mode}>; using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; @@ -101,7 +122,7 @@ ck_tile::FmhaFwdPagedKVKernel; using trait_{F_idx} = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}, {F_sink}>; + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}, {F_sink_mode}>; template<> float fmha_fwd_pagedkv_(const ck_tile::stream_config& s, fmha_fwd_pagedkv_args a) @@ -130,9 +151,9 @@ }} """ -FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && (t.has_sink == {F_sink}) && +FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && (t.has_sink == {F_stream_sink}) && (t.has_gptoss_sink == {F_gptoss_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip},{F_sink}>; + using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip},{F_sink_mode}>; return fmha_fwd_pagedkv_(s, a); }} """ @@ -258,7 +279,7 @@ class FmhaFwdPipeline: F_squant: str # F_mask: str # value from MASK_MAP F_skip: str # true/false - F_sink: str # true/false + F_sink: str # "none" / "stream" / "gptoss" / "both" @property def name(self) -> str: @@ -323,10 +344,7 @@ def pad_name() -> str: n += "_pagedkv" else: n += "_npagedkv" - if self.F_sink == "t": - n += "_sink" - else: - n += "_nsink" + n += SINK_NAME_MAP[self.F_sink] return n @@ -370,7 +388,9 @@ def api(self) -> str: F_lse=BOOL_MAP[trait.lse], F_pagedkv=BOOL_MAP[trait.pagedkv], F_skip=BOOL_MAP[trait.skip], - F_sink=BOOL_MAP[trait.sink], + F_sink_mode=SINK_MODE_MAP[trait.sink], + F_stream_sink=SINK_MODE_DISPATCH_MAP[trait.sink][0], + F_gptoss_sink=SINK_MODE_DISPATCH_MAP[trait.sink][1], F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, @@ -488,7 +508,7 @@ def template(self) -> str: F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv], F_squant=BOOL_MAP[self.F_pipeline.F_squant], F_skip=BOOL_MAP[self.F_pipeline.F_skip], - F_sink=BOOL_MAP[self.F_pipeline.F_sink], + F_sink_mode=SINK_MODE_MAP[self.F_pipeline.F_sink], F_occupancy=self.F_tile.F_occupancy, F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], @@ -541,7 +561,9 @@ def api_trait(self) -> FmhaFwdApiTrait: class KernelComponentFactoryBase: @staticmethod - def get_pipelines(dtype, hdim, mask_impl) -> List[FmhaFwdPipeline]: + def get_pipelines( + dtype, hdim, mask_impl, sink_modes=("none",) + ) -> List[FmhaFwdPipeline]: # this function will populate a list possible pipelines # TODO: the order of List matters! the later in this list will be also be checked later # TODO: currently for qr_pagedkv pipeline, let "t" padding to appear later!! @@ -555,17 +577,17 @@ def get_pipelines(dtype, hdim, mask_impl) -> List[FmhaFwdPipeline]: BIAS_MAP.keys(), ["t"], ["f"], - ["t", "f"], + sink_modes, ): pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "f", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip, sink)) # fmt: skip pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip, sink)) # fmt: skip elif dtype in ["fp8", "bf8"]: # no need lse/dropout kernels - for logits, mask, bias in itertools.product( - ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys() + for logits, mask, bias, sink in itertools.product( + ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), sink_modes ): - pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "f", "f", "f", "f", logits, bias, "f", "t", squant, mask, "f", "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", "t", squant, mask, "f", "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "f", "f", "f", "f", logits, bias, "f", "t", squant, mask, "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", "t", squant, mask, "f", sink)) # fmt: skip elif dtype in ["fp8fp16", "fp8bf16"]: pass # TODO else: @@ -655,7 +677,12 @@ def get_factory(target: str): def get_fwd_blobs( - targets: List[str], kernel_filter: Optional[str], receipt, optdim_list, mask_impl + targets: List[str], + kernel_filter: Optional[str], + receipt, + optdim_list, + mask_impl, + sink_modes=("none",), ) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: gen = list() api_pool = FmhaFwdApiPool(mask_impl) @@ -669,7 +696,7 @@ def get_fwd_blobs( for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): tile = d[hdim_str] hdim = int(hdim_str) - for pipeline in factory.get_pipelines(dtype, hdim, mask_impl): + for pipeline in factory.get_pipelines(dtype, hdim, mask_impl, sink_modes): # if pipeline.F_pagedkv == "f": # continue if mode == "group": @@ -709,7 +736,7 @@ def get_fwd_blobs( cond &= pipeline.F_bias in ["no", "alibi"] cond &= pipeline.F_squant == "f" cond &= pipeline.F_skip == "f" - cond &= pipeline.F_sink == "f" + cond &= pipeline.F_sink == "none" if not cond: continue # PyTorch integration @@ -719,7 +746,7 @@ def get_fwd_blobs( cond &= pipeline.F_bias in ["no", "bias"] cond &= pipeline.F_squant == "f" cond &= pipeline.F_skip == "f" - cond &= pipeline.F_sink == "f" + cond &= pipeline.F_sink == "none" if not cond: continue # Aiter(mha_fwd) integration @@ -773,9 +800,10 @@ def write_blobs( receipt, optdim_list, mask_impl, + sink_modes=("none",), ) -> None: api_pool, kernels = get_fwd_blobs( - targets, kernel_filter, receipt, optdim_list, mask_impl + targets, kernel_filter, receipt, optdim_list, mask_impl, sink_modes ) for kernel in kernels: write_single_fwd_kernel(kernel, output_dir) @@ -789,10 +817,11 @@ def list_blobs( receipt, optdim_list, mask_impl, + sink_modes=("none",), ) -> None: with file_path.open("a") as f: _, kernels = get_fwd_blobs( - targets, kernel_filter, receipt, optdim_list, mask_impl + targets, kernel_filter, receipt, optdim_list, mask_impl, sink_modes ) for kernel in kernels: f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n") diff --git a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp index 521f1e4738e..0c6f0372b6a 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -1393,8 +1393,8 @@ template + bool kSkipMinSeqlenQ_ = false, + ck_tile::FmhaSinkMode kSinkMode_ = ck_tile::FmhaSinkMode::kNone> struct fmha_fwd_traits_ { static constexpr ck_tile::index_t HDim = HDim_; @@ -1420,7 +1420,12 @@ struct fmha_fwd_traits_ static constexpr bool kPadDv = kPadDv_; static constexpr bool kUseTrLoad = kUseTrLoad_; static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; - static constexpr bool kHasSink = kHasSink_; + static constexpr ck_tile::FmhaSinkMode kSinkMode = kSinkMode_; + static constexpr bool kHasSink = (kSinkMode != ck_tile::FmhaSinkMode::kNone); + static constexpr bool kHasStreamSink = (kSinkMode == ck_tile::FmhaSinkMode::kStreamLLM || + kSinkMode == ck_tile::FmhaSinkMode::kBoth); + static constexpr bool kHasGptOssSink = + (kSinkMode == ck_tile::FmhaSinkMode::kGptOss || kSinkMode == ck_tile::FmhaSinkMode::kBoth); }; template + ck_tile::FmhaSinkMode::kNone> { static constexpr auto kKVMemoryLayout = kKVMemoryLayout_; static constexpr auto kKVLookupTable = kKVLookupTable_; @@ -1506,8 +1511,8 @@ template + bool kSkipMinSeqlenQ_ = false, + ck_tile::FmhaSinkMode kSinkMode_ = ck_tile::FmhaSinkMode::kNone> struct fmha_fwd_pagedkv_traits_ { static constexpr ck_tile::index_t HDim = HDim_; @@ -1532,7 +1537,12 @@ struct fmha_fwd_pagedkv_traits_ static constexpr bool kPadD = kPadD_; static constexpr bool kPadDv = kPadDv_; static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; - static constexpr bool kHasSink = kHasSink_; + static constexpr ck_tile::FmhaSinkMode kSinkMode = kSinkMode_; + static constexpr bool kHasSink = (kSinkMode != ck_tile::FmhaSinkMode::kNone); + static constexpr bool kHasStreamSink = (kSinkMode == ck_tile::FmhaSinkMode::kStreamLLM || + kSinkMode == ck_tile::FmhaSinkMode::kBoth); + static constexpr bool kHasGptOssSink = + (kSinkMode == ck_tile::FmhaSinkMode::kGptOss || kSinkMode == ck_tile::FmhaSinkMode::kBoth); }; template @@ -1555,7 +1565,7 @@ template @@ -1670,7 +1685,8 @@ struct fmha_fwd_traits bool has_dropout; quant_scale_enum qscale_type; bool skip_min_seqlen_q = false; - bool has_sink = false; + bool has_sink = false; // StreamLLM sliding-window sink + bool has_gptoss_sink = false; // GPT-OSS learnable softmax bias sink // TODO: padding check is inside this api }; float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&); @@ -1689,7 +1705,8 @@ struct fmha_fwd_pagedkv_traits bool use_pagedkv = true; bool do_fp8_static_quant = false; bool skip_min_seqlen_q = false; - bool has_sink = false; + bool has_sink = false; // StreamLLM sliding-window sink + bool has_gptoss_sink = false; // GPT-OSS learnable softmax bias sink // TODO: padding check is inside this api }; @@ -1709,7 +1726,8 @@ struct fmha_fwd_splitkv_traits bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum bool has_lse; bool do_fp8_static_quant = false; - bool has_sink = false; + bool has_sink = false; // StreamLLM sliding-window sink + bool has_gptoss_sink = false; // GPT-OSS learnable softmax bias sink // TODO: padding check is inside this api }; float fmha_fwd_splitkv(fmha_fwd_splitkv_traits, diff --git a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index 40b80063810..3b3de84fd06 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -1133,7 +1133,8 @@ fwd_result fmha_fwd_run(mode_enum mode, traits.has_logits_soft_cap = 0.f < logits_soft_cap; traits.mask_type = mask.type; traits.bias_type = bias.type; - traits.has_sink = mask.sink > 0 ? true : false; + traits.has_sink = mask.sink > 0 ? true : false; // StreamLLM sink + traits.has_gptoss_sink = init_sink_value != 0; // GPT-OSS sink traits.has_lse = lse; if constexpr(std::is_same_v>) diff --git a/projects/composablekernel/example/ck_tile/01_fmha/generate.py b/projects/composablekernel/example/ck_tile/01_fmha/generate.py index a5a2d085635..5e86cf83f18 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/generate.py +++ b/projects/composablekernel/example/ck_tile/01_fmha/generate.py @@ -46,6 +46,7 @@ def write_blobs( optdim_list: List[int], receipt, mask_impl, + sink_modes=("none",), ) -> None: if output_dir is None: output_dir = Path(__file__).parent @@ -56,7 +57,18 @@ def write_blobs( for api, kernel_filter in zip(api_list, filters_list): handler = handlers[api][HandlerId.WRITE_BLOBS] - handler(targets, output_dir, kernel_filter, receipt, optdim_list, mask_impl) + if api == "bwd": + handler(targets, output_dir, kernel_filter, receipt, optdim_list, mask_impl) + else: + handler( + targets, + output_dir, + kernel_filter, + receipt, + optdim_list, + mask_impl, + sink_modes, + ) # list all the files that will be generated @@ -68,6 +80,7 @@ def list_blobs( optdim_list: List[int], receipt, mask_impl, + sink_modes=("none",), ) -> None: assert output_file is not None file_path = Path(output_file) @@ -77,7 +90,18 @@ def list_blobs( for api, kernel_filter in zip(api_list, filters_list): handler = handlers[api][HandlerId.LIST_BLOBS] - handler(targets, file_path, kernel_filter, receipt, optdim_list, mask_impl) + if api == "bwd": + handler(targets, file_path, kernel_filter, receipt, optdim_list, mask_impl) + else: + handler( + targets, + file_path, + kernel_filter, + receipt, + optdim_list, + mask_impl, + sink_modes, + ) if __name__ == "__main__": @@ -150,12 +174,22 @@ def list_blobs( + "eg. --optdim=32,64,128,256", ) + parser.add_argument( + "--sink", + default="none", + required=False, + help="comma-separated list of sink modes to generate instances for. " + + "Valid values: none, stream, gptoss, both. Default: none (only no-sink instances). " + + "eg. --sink=none,stream,gptoss", + ) + args = parser.parse_args() targets = args.targets.split(",") api_list = args.direction.split(",") filter_list = args.filter.split(",") filter_list.extend([""] * (len(api_list) - len(filter_list))) optdim_list = [int(hdim) for hdim in args.optdim.split(",")] + sink_modes = tuple(args.sink.split(",")) if args.list_blobs is not None: list_blobs( @@ -166,6 +200,7 @@ def list_blobs( optdim_list, int(args.receipt), mask_impl=args.mask, + sink_modes=sink_modes, ) else: write_blobs( @@ -176,4 +211,5 @@ def list_blobs( optdim_list, int(args.receipt), mask_impl=args.mask, + sink_modes=sink_modes, ) diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 16f5b00bb11..8de97eedc6e 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -8,6 +8,7 @@ #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp" #include "ck_tile/ops/fmha/block/variants.hpp" +#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include #include @@ -67,7 +68,10 @@ struct FmhaFwdKernel static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum; static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ; + static constexpr FmhaSinkMode kSinkMode = FmhaPipeline::kSinkMode; static constexpr bool kHasSink = FmhaPipeline::kHasSink; + static constexpr bool kHasStreamSink = FmhaPipeline::kHasStreamSink; + static constexpr bool kHasGptOssSink = FmhaPipeline::kHasGptOssSink; using AttentionVariant = ck_tile::remove_cvref_t; using FmhaMask = ck_tile::remove_cvref_t; @@ -1439,10 +1443,14 @@ struct FmhaFwdKernel long_index_t batch_offset_q_descale = 0; long_index_t batch_offset_k_descale = 0; long_index_t batch_offset_v_descale = 0; - const float sink_value = - kargs.sink_ptr != nullptr - ? (*(static_cast(kargs.sink_ptr) + i_nhead)) / kargs.scale_s - : -numeric::infinity(); + // GPT-OSS sink value: only computed and passed when kHasGptOssSink=true. + // When kHasGptOssSink=false, pipelines do not use sink_value at all. + const float sink_value = [&]() { + if constexpr(kHasGptOssSink) + return (*(static_cast(kargs.sink_ptr) + i_nhead)) / kargs.scale_s; + else + return -numeric::infinity(); + }(); if constexpr(kIsGroupMode) { @@ -2184,10 +2192,12 @@ struct FmhaFwdKernel constexpr bool PrefillCase = FmhaPipeline::kM0 > 64; // divide problem const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); - const float sink_value = - kargs.sink_ptr != nullptr - ? (*(static_cast(kargs.sink_ptr) + i_nhead)) / kargs.scale_s - : -numeric::infinity(); + const float sink_value = [&]() { + if constexpr(kHasGptOssSink) + return (*(static_cast(kargs.sink_ptr) + i_nhead)) / kargs.scale_s; + else + return -numeric::infinity(); + }(); const index_t i_m0 = i_tile_m * FmhaPipeline::kM0; const index_t i_n1 = i_tile_n * FmhaPipeline::kN1; diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp index 89bd22c4715..76e635540df 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp @@ -7,6 +7,7 @@ #include "ck_tile/ops/common.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/variants.hpp" +#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include #include @@ -55,7 +56,10 @@ struct FmhaFwdPagedKVKernel static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ; static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV; + static constexpr FmhaSinkMode kSinkMode = FmhaPipeline::kSinkMode; static constexpr bool kHasSink = FmhaPipeline::kHasSink; + static constexpr bool kHasStreamSink = FmhaPipeline::kHasStreamSink; + static constexpr bool kHasGptOssSink = FmhaPipeline::kHasGptOssSink; using AttentionVariant = ck_tile::remove_cvref_t; using FmhaMask = ck_tile::remove_cvref_t; @@ -102,7 +106,8 @@ struct FmhaFwdPagedKVKernel (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) + (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + - (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ) + (kHasSink ? "_sink" : "_nsink" ); + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ) + + (kSinkMode == FmhaSinkMode::kNone ? "_nsink" : kSinkMode == FmhaSinkMode::kStreamLLM ? "_ssink" : kSinkMode == FmhaSinkMode::kGptOss ? "_gsink" : "_bsink"); #undef _SS_ #undef _TS_ // clang-format on diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index d6c4d70fee5..55b9f63e79c 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -7,6 +7,7 @@ #include "ck_tile/ops/common.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/variants.hpp" +#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include #include @@ -51,7 +52,10 @@ struct FmhaFwdSplitKVKernel static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV; + static constexpr FmhaSinkMode kSinkMode = FmhaPipeline::Problem::kSinkMode; static constexpr bool kHasSink = FmhaPipeline::Problem::kHasSink; + static constexpr bool kHasStreamSink = FmhaPipeline::Problem::kHasStreamSink; + static constexpr bool kHasGptOssSink = FmhaPipeline::Problem::kHasGptOssSink; static constexpr bool kMergeNumHeadGroupsSeqLenQ = FmhaPipeline::Problem::kMergeNumHeadGroupsSeqLenQ; using AttentionVariant = ck_tile::remove_cvref_t; @@ -102,7 +106,8 @@ struct FmhaFwdSplitKVKernel "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) + (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + - (kDoFp8StaticQuant ? "_squant" : "_nsquant") + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ) + (kHasSink ? "_sink" : "_nsink" ); + (kDoFp8StaticQuant ? "_squant" : "_nsquant") + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ) + + (kSinkMode == FmhaSinkMode::kNone ? "_nsink" : kSinkMode == FmhaSinkMode::kStreamLLM ? "_ssink" : kSinkMode == FmhaSinkMode::kGptOss ? "_gsink" : "_bsink"); #undef _SS_ #undef _TS_ // clang-format on diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index a8b94b6e417..b4a2db32743 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp" @@ -291,6 +292,10 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync static constexpr bool kHasDropout = Problem::kHasDropout; static constexpr auto kKVMemoryLayout = Problem::kKVMemoryLayout; static constexpr auto QScaleEnum = Problem::QScaleEnum; + static constexpr FmhaSinkMode kSinkMode = Problem::kSinkMode; + static constexpr bool kHasSink = Problem::kHasSink; + static constexpr bool kHasStreamSink = Problem::kHasStreamSink; + static constexpr bool kHasGptOssSink = Problem::kHasGptOssSink; // For KV_BLOCKSCALE: shift value for exp2(x + shift) to scale P to [0, 2^shift] // This avoids explicit P *= scale_p and v_descale /= scale_p operations @@ -526,8 +531,9 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync auto l = MLBlockTileType{}; clear_tile(o_acc); - if(__builtin_isinf_sign(sink_v) >= 0) + if constexpr(kHasGptOssSink) { + // sink_v is always valid when kHasGptOssSink=true; no isinf check needed. #if CK_TILE_FMHA_FWD_FAST_EXP2 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || BiasEnum == BlockAttentionBiasEnum::ALIBI) @@ -561,7 +567,15 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync { auto lse = make_static_distributed_tensor(m.get_tile_distribution()); - set_tile(lse, SMPLComputeDataType{sink_v * scale_s}); + if constexpr(kHasGptOssSink) + { + const SMPLComputeDataType sink_lse = sink_v * scale_s; + set_tile(lse, sink_lse); + } + else + { + set_tile(lse, -numeric::infinity()); + } store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); } buffer_load_fence(0); // rocm-6.1, if whole tile is masked out, need to fence(0) diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp index 3f6b9bc44fc..d0a61a81bd1 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include "ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" @@ -58,7 +59,10 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kIsPagedKV = Problem::kIsPagedKV; + static constexpr FmhaSinkMode kSinkMode = Problem::kSinkMode; static constexpr bool kHasSink = Problem::kHasSink; + static constexpr bool kHasStreamSink = Problem::kHasStreamSink; + static constexpr bool kHasGptOssSink = Problem::kHasGptOssSink; static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || @@ -229,7 +233,7 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS auto l = MLBlockTileType{}; clear_tile(o_acc); - if(__builtin_isinf_sign(sink_v) >= 0) + if constexpr(kHasGptOssSink) { #if CK_TILE_FMHA_FWD_FAST_EXP2 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || @@ -249,7 +253,7 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS } const auto q_origin = q_dram_window.get_window_origin(); const auto tile_range_result = [&mask, &q_origin]() { - if constexpr(kHasSink) + if constexpr(kHasStreamSink) return mask.GetSinkTileRangeAlongX( q_origin.at(number<0>{}), number{}, number{}); else @@ -275,7 +279,8 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS { auto lse = make_static_distributed_tensor(m.get_tile_distribution()); - set_tile(lse, SMPLComputeDataType{sink_v * scale_s}); + const SMPLComputeDataType sink_lse = sink_v * scale_s; + set_tile(lse, sink_lse); store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); } @@ -314,7 +319,7 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); const index_t bias_n_offset = [&]() { - if constexpr(kHasSink) + if constexpr(kHasStreamSink) return kv_load_start; else return logical_seqlen_k_start - @@ -360,7 +365,7 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS } const bool is_sink_tile = ((num_sink_loop - 1) == i_total_loops); const auto k_move_offset = [&]() { - if constexpr(kHasSink) + if constexpr(kHasStreamSink) return is_sink_tile ? logical_seqlen_k_start - sink_seq_end + kN0 : kN0; else return kN0; @@ -530,7 +535,7 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS }); }; - if constexpr(kHasSink) + if constexpr(kHasStreamSink) { apply_mask([&](auto row, auto col) { return mask.IsOutOfSinkBound(row, col); diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp index 1af244751ba..07461516348 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" @@ -57,7 +58,10 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kIsPagedKV = Problem::kIsPagedKV; static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits; + static constexpr FmhaSinkMode kSinkMode = Problem::kSinkMode; static constexpr bool kHasSink = Problem::kHasSink; + static constexpr bool kHasStreamSink = Problem::kHasStreamSink; + static constexpr bool kHasGptOssSink = Problem::kHasGptOssSink; static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || @@ -255,10 +259,18 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS auto l = MLBlockTileType{}; clear_tile(o_acc); - if((__builtin_isinf_sign(sink_v) >= 0) && i_split == 0) + if constexpr(kHasGptOssSink) { - set_tile(m, SMPLComputeDataType{sink_v * C_LOG2E}); - set_tile(l, SMPLComputeDataType{1.0f}); + if(i_split == 0) + { + set_tile(m, SMPLComputeDataType{sink_v * C_LOG2E}); + set_tile(l, SMPLComputeDataType{1.0f}); + } + else + { + set_tile(m, -numeric::infinity()); + clear_tile(l); + } } else { @@ -268,7 +280,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS const auto q_origin = q_dram_window.get_window_origin(); const auto tile_range_result = [&mask, &q_origin, num_splits, i_split]() { - if constexpr(kHasSink) + if constexpr(kHasStreamSink) return mask.GetSinkTileRangeAlongX( q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split); else @@ -293,7 +305,8 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS { auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); - set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s}); + const SMPLComputeDataType sink_lse = sink_v * scale_s; + set_tile(lse_acc, sink_lse); if(get_thread_local_1d_id() < kM0) { store_tile(lse_acc_dram_window_tmp, @@ -310,10 +323,13 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS { auto [start, end] = mask.GetTileRangeAlongX( q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split - 1); - if((__builtin_isinf_sign(sink_v) >= 0) && start >= end) + if constexpr(kHasGptOssSink) { - set_tile(m, SMPLComputeDataType{sink_v}); - set_tile(l, SMPLComputeDataType{1.0f}); + if(start >= end) + { + set_tile(m, SMPLComputeDataType{sink_v}); + set_tile(l, SMPLComputeDataType{1.0f}); + } } } const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset; @@ -345,7 +361,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); const index_t bias_n_offset = [&]() { - if constexpr(kHasSink) + if constexpr(kHasStreamSink) return kv_load_start; else return logical_seqlen_k_start - @@ -409,7 +425,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS clear_tile(s_acc); // initialize C const bool is_sink_tile = ((num_sink_loop - 1) == i_total_loops); const auto k_move_offset = [&]() { - if constexpr(kHasSink) + if constexpr(kHasStreamSink) return is_sink_tile ? logical_seqlen_k_start - sink_seq_end + kN0 : kN0; else return kN0; @@ -585,7 +601,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS }); }; - if constexpr(kHasSink) + if constexpr(kHasStreamSink) { apply_mask( [&](auto row, auto col) { return mask.IsOutOfSinkBound(row, col); }); diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp index 842b48013a2..88cf2e3f57f 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp" @@ -57,7 +58,10 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kIsPagedKV = Problem::kIsPagedKV; static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits; + static constexpr FmhaSinkMode kSinkMode = Problem::kSinkMode; static constexpr bool kHasSink = Problem::kHasSink; + static constexpr bool kHasStreamSink = Problem::kHasStreamSink; + static constexpr bool kHasGptOssSink = Problem::kHasGptOssSink; static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || @@ -229,18 +233,26 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS auto l = MLBlockTileType{}; clear_tile(o_acc); - if((__builtin_isinf_sign(sink_v) >= 0) && i_split == 0) + if constexpr(kHasGptOssSink) { + if(i_split == 0) + { #if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - BiasEnum == BlockAttentionBiasEnum::ALIBI) - set_tile(m, sink_v * C_LOG2E * scale_s); - else - set_tile(m, sink_v * C_LOG2E); + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + set_tile(m, sink_v * C_LOG2E * scale_s); + else + set_tile(m, sink_v * C_LOG2E); #else - set_tile(m, sink_v); + set_tile(m, sink_v); #endif - set_tile(l, SMPLComputeDataType{1.0f}); + set_tile(l, SMPLComputeDataType{1.0f}); + } + else + { + set_tile(m, -numeric::infinity()); + clear_tile(l); + } } else { @@ -250,7 +262,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS const auto q_origin = q_dram_window.get_window_origin(); const auto tile_range_result = [&mask, &q_origin, num_splits, i_split]() { - if constexpr(kHasSink) + if constexpr(kHasStreamSink) return mask.GetSinkTileRangeAlongX( q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split); else @@ -277,7 +289,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS { auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); - set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s}); + const SMPLComputeDataType sink_lse = sink_v * scale_s; + set_tile(lse_acc, sink_lse); store_tile(lse_acc_dram_window_tmp, tile_elementwise_in(lse_acc_element_func, lse_acc)); } @@ -292,18 +305,26 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS { auto [start, end] = mask.GetTileRangeAlongX( q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split - 1); - if((__builtin_isinf_sign(sink_v) >= 0) && start >= end) + if constexpr(kHasGptOssSink) { + if(start >= end) + { #if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - BiasEnum == BlockAttentionBiasEnum::ALIBI) - set_tile(m, sink_v * C_LOG2E * scale_s); - else - set_tile(m, sink_v * C_LOG2E); + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + set_tile(m, sink_v * C_LOG2E * scale_s); + else + set_tile(m, sink_v * C_LOG2E); #else - set_tile(m, sink_v); + set_tile(m, sink_v); #endif - set_tile(l, SMPLComputeDataType{1.0f}); + set_tile(l, SMPLComputeDataType{1.0f}); + } + else + { + set_tile(m, -numeric::infinity()); + clear_tile(l); + } } else { @@ -341,7 +362,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); const index_t bias_n_offset = [&]() { - if constexpr(kHasSink) + if constexpr(kHasStreamSink) return kv_load_start; else return logical_seqlen_k_start - @@ -388,7 +409,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS const bool is_sink_tile = ((num_sink_loop - 1) == i_total_loops); const auto k_move_offset = [&]() { - if constexpr(kHasSink) + if constexpr(kHasStreamSink) return is_sink_tile ? logical_seqlen_k_start - sink_seq_end + kN0 : kN0; else return kN0; @@ -562,7 +583,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS }); }; - if constexpr(kHasSink) + if constexpr(kHasStreamSink) { apply_mask( [&](auto row, auto col) { return mask.IsOutOfSinkBound(row, col); }); diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index 87db7b85b9e..cb4efbbef4e 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp" #include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp" +#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" namespace ck_tile { @@ -72,7 +73,10 @@ struct BlockFmhaPipelineProblem static constexpr bool kHasDropout = Traits::kHasDropout; static constexpr auto QScaleEnum = Traits::QScaleEnum; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; + static constexpr FmhaSinkMode kSinkMode = Traits::kSinkMode; static constexpr bool kHasSink = Traits::kHasSink; + static constexpr bool kHasStreamSink = Traits::kHasStreamSink; + static constexpr bool kHasGptOssSink = Traits::kHasGptOssSink; }; template = 0) + if constexpr(kHasGptOssSink) { + // sink_v is always valid when kHasGptOssSink=true; no isinf check needed. #if CK_TILE_FMHA_FWD_FAST_EXP2 if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI || BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) @@ -304,7 +309,7 @@ struct BlockFmhaPipelineQRKSVS const auto q_origin = q_dram_window.get_window_origin(); const auto tile_range_result = [&mask, &q_origin]() { - if constexpr(kHasSink) + if constexpr(kHasStreamSink) return mask.GetSinkTileRangeAlongX( q_origin.at(number<0>{}), number{}, number{}); else @@ -333,9 +338,12 @@ struct BlockFmhaPipelineQRKSVS auto lse = make_static_distributed_tensor(m.get_tile_distribution()); - if(__builtin_isinf_sign(sink_v) >= 0) + if constexpr(kHasGptOssSink) { - set_tile(lse, SMPLComputeDataType{sink_v * scale_s}); + // Precompute sink_lse to avoid sink_v*scale_s in set_tile, which + // triggers a compiler register-allocation bug for dropout kernels. + const SMPLComputeDataType sink_lse = sink_v * scale_s; + set_tile(lse, sink_lse); } else { @@ -634,7 +642,7 @@ struct BlockFmhaPipelineQRKSVS #endif } } - if constexpr(kHasSink) + if constexpr(kHasStreamSink) { if(i_total_loops == 0) move_tile_window(bias_dram_window, {0, seqlen_k_start - sink_seq_end}); @@ -665,7 +673,7 @@ struct BlockFmhaPipelineQRKSVS }); }; - if constexpr(kHasSink) + if constexpr(kHasStreamSink) { apply_mask([&](auto&&... args) { return variant.LogitsSinkMask(std::forward(args)...); @@ -804,7 +812,7 @@ struct BlockFmhaPipelineQRKSVS auto randval_ptr = reinterpret_cast(smem_ptr); index_t seq_offset = [&]() { - if constexpr(!kHasSink) + if constexpr(!kHasStreamSink) return seqlen_k_start + i_total_loops * kN0; const bool in_sink_phase = (num_sink_loop > i_total_loops); @@ -945,7 +953,7 @@ struct BlockFmhaPipelineQRKSVS }); } // move K tile windows - if constexpr(kHasSink) + if constexpr(kHasStreamSink) { if(i_total_loops == 0) { @@ -971,7 +979,31 @@ struct BlockFmhaPipelineQRKSVS } } while(++i_total_loops < num_total_loop); - // store lse + // finally, O -- normalize BEFORE storing LSE to avoid VGPR aliasing: + // the lse tile creation (below) may reuse o_acc VGPRs; normalizing first + // ensures o_acc is fully consumed before those registers can be reused. + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = [&]() { + // When bias carries -inf masks the denominator can be zero; guard the normalization + // so we do not divide by zero after a fully masked row. + if constexpr(FmhaMask::IsMasking || + BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; + } + else + return 1 / l[i_idx]; + }(); + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + }); + + // store lse -- AFTER O normalization to prevent VGPR reuse corruption if constexpr(kStoreLSE) { auto lse = make_static_distributed_tensor(m.get_tile_distribution()); @@ -1013,28 +1045,6 @@ struct BlockFmhaPipelineQRKSVS store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); } - // finally, O - constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); - - sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - const auto tmp = [&]() { - // When bias carries -inf masks the denominator can be zero; guard the normalization - // so we do not divide by zero after a fully masked row. - if constexpr(FmhaMask::IsMasking || - BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; - } - else - return 1 / l[i_idx]; - }(); - sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - o_acc(i_j_idx) *= tmp; - }); - }); - o_acc = tile_elementwise_in(o_acc_element_func, o_acc); return o_acc; diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 7b97d01fa4f..193333a093d 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp" @@ -63,7 +64,10 @@ struct BlockFmhaPipelineQRKSVSAsync static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kHasDropout = Problem::kHasDropout; + static constexpr FmhaSinkMode kSinkMode = Problem::kSinkMode; static constexpr bool kHasSink = Problem::kHasSink; + static constexpr bool kHasStreamSink = Problem::kHasStreamSink; + static constexpr bool kHasGptOssSink = Problem::kHasGptOssSink; // For BLOCKSCALE: shift value for exp2(x + shift) to scale P to [0, 2^shift] static constexpr float OCP_FP8_SHIFT = 8.0f; @@ -292,7 +296,7 @@ struct BlockFmhaPipelineQRKSVSAsync auto l = MLBlockTileType{}; clear_tile(o_acc); - if(__builtin_isinf_sign(sink_v) >= 0) + if constexpr(kHasGptOssSink) { #if CK_TILE_FMHA_FWD_FAST_EXP2 if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI || @@ -313,7 +317,7 @@ struct BlockFmhaPipelineQRKSVSAsync __builtin_amdgcn_sched_barrier(0); const auto q_origin = q_dram_window.get_window_origin(); const auto tile_range_result = [&mask, &q_origin]() { - if constexpr(kHasSink) + if constexpr(kHasStreamSink) return mask.GetSinkTileRangeAlongX( q_origin.at(number<0>{}), number{}, number{}); else @@ -343,7 +347,8 @@ struct BlockFmhaPipelineQRKSVSAsync { auto lse = make_static_distributed_tensor(m.get_tile_distribution()); - set_tile(lse, SMPLComputeDataType{sink_v * scale_s}); + const SMPLComputeDataType sink_lse = sink_v * scale_s; + set_tile(lse, sink_lse); store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); } @@ -534,7 +539,7 @@ struct BlockFmhaPipelineQRKSVSAsync #endif } } - if constexpr(kHasSink) + if constexpr(kHasStreamSink) { if(i_total_loops == 0) move_tile_window(bias_dram_window, {0, seqlen_k_start - sink_seq_end}); @@ -566,7 +571,7 @@ struct BlockFmhaPipelineQRKSVSAsync }); }; - if constexpr(kHasSink) + if constexpr(kHasStreamSink) { apply_mask([&](auto&&... args) { return variant.LogitsSinkMask(std::forward(args)...); @@ -747,7 +752,7 @@ struct BlockFmhaPipelineQRKSVSAsync reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); index_t seq_offset = [&]() { - if constexpr(!kHasSink) + if constexpr(!kHasStreamSink) return seqlen_k_start + i_total_loops * kN0; const bool in_sink_phase = (num_sink_loop > i_total_loops); @@ -845,7 +850,7 @@ struct BlockFmhaPipelineQRKSVSAsync i_total_loops++; if(i_total_loops < num_total_loop) { - if constexpr(kHasSink) + if constexpr(kHasStreamSink) { if(i_total_loops == 0) { diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp index c0d5ca291f0..ef505b07abb 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" @@ -69,7 +70,10 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kHasUnevenSplits = true; + static constexpr FmhaSinkMode kSinkMode = Problem::kSinkMode; static constexpr bool kHasSink = Problem::kHasSink; + static constexpr bool kHasStreamSink = Problem::kHasStreamSink; + static constexpr bool kHasGptOssSink = Problem::kHasGptOssSink; static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || @@ -194,7 +198,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload auto l = MLBlockTileType{}; clear_tile(o_acc); - if(__builtin_isinf_sign(sink_v) >= 0) + if constexpr(kHasGptOssSink) { #if CK_TILE_FMHA_FWD_FAST_EXP2 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || @@ -228,7 +232,8 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload { auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); - set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s}); + const SMPLComputeDataType sink_lse = sink_v * scale_s; + set_tile(lse_acc, sink_lse); store_tile(lse_acc_dram_window_tmp, lse_acc); } @@ -714,7 +719,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload auto l = MLBlockTileType{}; clear_tile(o_acc); - if(__builtin_isinf_sign(sink_v) >= 0) + if constexpr(kHasGptOssSink) { #if CK_TILE_FMHA_FWD_FAST_EXP2 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || @@ -748,7 +753,8 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload { auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); - set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s}); + const SMPLComputeDataType sink_lse = sink_v * scale_s; + set_tile(lse_acc, sink_lse); store_tile(lse_acc_dram_window_tmp, lse_acc); } diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index 0670985e4f3..456b11a556e 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -11,6 +11,19 @@ namespace ck_tile { +// Compile-time sink mode for FMHA forward kernels. +// Two independent sink mechanisms that can be combined: +// kStreamLLM: Sliding-window attention with initial-token sink (ICLR 2024, MIT HAN Lab) +// Controls KV tile loading schedule and per-pixel mask logic. +// kGptOss: GPT-OSS learnable softmax bias (OpenAI); initializes softmax m/l. +enum class FmhaSinkMode : int +{ + kNone = 0, // No sink — zero overhead + kStreamLLM = 1, // StreamingLLM sliding-window sink + kGptOss = 2, // GPT-OSS learnable softmax bias sink + kBoth = 3, // Both sinks simultaneously +}; + template + index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */ + bool kSkipMinSeqlenQ_ = false, /* skip min seqlen q while chunked prefill */ + FmhaSinkMode kSinkMode_ = FmhaSinkMode::kNone> struct TileFmhaTraits { static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; @@ -38,7 +51,13 @@ struct TileFmhaTraits static constexpr auto QScaleEnum = QScaleEnum_; static constexpr index_t kBlockPerCu = kBlockPerCu_; static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; - static constexpr bool kHasSink = kHasSink_; + static constexpr FmhaSinkMode kSinkMode = kSinkMode_; + // Derived convenience constants + static constexpr bool kHasSink = (kSinkMode != FmhaSinkMode::kNone); + static constexpr bool kHasStreamSink = + (kSinkMode == FmhaSinkMode::kStreamLLM || kSinkMode == FmhaSinkMode::kBoth); + static constexpr bool kHasGptOssSink = + (kSinkMode == FmhaSinkMode::kGptOss || kSinkMode == FmhaSinkMode::kBoth); }; template + FmhaSinkMode::kNone> { static constexpr auto kKVMemoryLayout = kKVMemoryLayout_; static constexpr auto kKVLookupTable = kKVLookupTable_; @@ -110,9 +129,9 @@ template 1 or fwd training is running */ bool kIsPagedKV_, bool kDoFp8StaticQuant_, - index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */ - bool kSkipMinSeqlenQ_ = false, /* skip min seqlen q while chunked prefill */ - bool kHasSink_ = false> + index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */ + bool kSkipMinSeqlenQ_ = false, /* skip min seqlen q while chunked prefill */ + FmhaSinkMode kSinkMode_ = FmhaSinkMode::kNone> struct TileFmhaFwdPagedKVTraits { static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; @@ -127,7 +146,12 @@ struct TileFmhaFwdPagedKVTraits static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; static constexpr index_t kBlockPerCu = kBlockPerCu_; static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; - static constexpr bool kHasSink = kHasSink_; + static constexpr FmhaSinkMode kSinkMode = kSinkMode_; + static constexpr bool kHasSink = (kSinkMode != FmhaSinkMode::kNone); + static constexpr bool kHasStreamSink = + (kSinkMode == FmhaSinkMode::kStreamLLM || kSinkMode == FmhaSinkMode::kBoth); + static constexpr bool kHasGptOssSink = + (kSinkMode == FmhaSinkMode::kGptOss || kSinkMode == FmhaSinkMode::kBoth); }; template + FmhaSinkMode kSinkMode_ = FmhaSinkMode::kNone> struct TileFmhaFwdSplitKVTraits { static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; @@ -160,7 +184,12 @@ struct TileFmhaFwdSplitKVTraits static constexpr bool kHasUnevenSplits = kHasUnevenSplits_; static constexpr bool kMergeNumHeadGroupsSeqLenQ = kMergeNumHeadGroupsSeqLenQ_; static constexpr index_t kBlockPerCu = kBlockPerCu_; - static constexpr bool kHasSink = kHasSink_; + static constexpr FmhaSinkMode kSinkMode = kSinkMode_; + static constexpr bool kHasSink = (kSinkMode != FmhaSinkMode::kNone); + static constexpr bool kHasStreamSink = + (kSinkMode == FmhaSinkMode::kStreamLLM || kSinkMode == FmhaSinkMode::kBoth); + static constexpr bool kHasGptOssSink = + (kSinkMode == FmhaSinkMode::kGptOss || kSinkMode == FmhaSinkMode::kBoth); }; template Date: Mon, 30 Mar 2026 03:01:56 +0000 Subject: [PATCH 02/25] fmha: remove unused sink constants from pipeline/problem/kernel layers After introducing FmhaSinkMode in d60c65deb, several re-exported constexpr members became dead declarations that no code actually reads: - kHasSink: removed from all pipeline structs, all kernel structs, and all three PipelineProblem structs. No site reads Pipeline::kHasSink or Problem::kHasSink. - kSinkMode: removed from all pipeline structs except pagedkv pipeline (pagedkv_kernel reads FmhaPipeline::kSinkMode for kernel name) and from fmha_fwd_kernel (only kHasGptOssSink is used there). splitkv and pagedkv kernels retain kSinkMode for kernel name generation. - kHasStreamSink: removed from fmha_fwd_kernel, splitkv_kernel, and pagedkv_kernel (none of them use it internally). - kHasGptOssSink: removed from splitkv_kernel and pagedkv_kernel. Retained in fmha_fwd_kernel where it guards sink_value computation. Problem layer retains kSinkMode (splitkv_kernel reads it via FmhaPipeline::Problem::kSinkMode), kHasStreamSink, and kHasGptOssSink (all pipelines read these via Problem::kHas*Sink). No functional change. Compilation and correctness tests (nsink, gsink) pass unchanged. --- .../include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp | 3 --- .../ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp | 3 --- .../ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp | 3 --- .../fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp | 1 - ...block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp | 2 -- .../fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp | 2 -- .../ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp | 3 --- .../ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 2 -- .../ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp | 2 -- .../pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp | 2 -- 10 files changed, 23 deletions(-) diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 8de97eedc6e..0d100f37850 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -68,9 +68,6 @@ struct FmhaFwdKernel static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum; static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ; - static constexpr FmhaSinkMode kSinkMode = FmhaPipeline::kSinkMode; - static constexpr bool kHasSink = FmhaPipeline::kHasSink; - static constexpr bool kHasStreamSink = FmhaPipeline::kHasStreamSink; static constexpr bool kHasGptOssSink = FmhaPipeline::kHasGptOssSink; using AttentionVariant = ck_tile::remove_cvref_t; diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp index 76e635540df..0234c342d01 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp @@ -57,9 +57,6 @@ struct FmhaFwdPagedKVKernel static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ; static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV; static constexpr FmhaSinkMode kSinkMode = FmhaPipeline::kSinkMode; - static constexpr bool kHasSink = FmhaPipeline::kHasSink; - static constexpr bool kHasStreamSink = FmhaPipeline::kHasStreamSink; - static constexpr bool kHasGptOssSink = FmhaPipeline::kHasGptOssSink; using AttentionVariant = ck_tile::remove_cvref_t; using FmhaMask = ck_tile::remove_cvref_t; diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index 55b9f63e79c..4fbf30e11b7 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -53,9 +53,6 @@ struct FmhaFwdSplitKVKernel static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV; static constexpr FmhaSinkMode kSinkMode = FmhaPipeline::Problem::kSinkMode; - static constexpr bool kHasSink = FmhaPipeline::Problem::kHasSink; - static constexpr bool kHasStreamSink = FmhaPipeline::Problem::kHasStreamSink; - static constexpr bool kHasGptOssSink = FmhaPipeline::Problem::kHasGptOssSink; static constexpr bool kMergeNumHeadGroupsSeqLenQ = FmhaPipeline::Problem::kMergeNumHeadGroupsSeqLenQ; using AttentionVariant = ck_tile::remove_cvref_t; diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp index d0a61a81bd1..6fc4925a8d1 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp @@ -60,7 +60,6 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kIsPagedKV = Problem::kIsPagedKV; static constexpr FmhaSinkMode kSinkMode = Problem::kSinkMode; - static constexpr bool kHasSink = Problem::kHasSink; static constexpr bool kHasStreamSink = Problem::kHasStreamSink; static constexpr bool kHasGptOssSink = Problem::kHasGptOssSink; diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp index 07461516348..4541cc7d488 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp @@ -58,8 +58,6 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kIsPagedKV = Problem::kIsPagedKV; static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits; - static constexpr FmhaSinkMode kSinkMode = Problem::kSinkMode; - static constexpr bool kHasSink = Problem::kHasSink; static constexpr bool kHasStreamSink = Problem::kHasStreamSink; static constexpr bool kHasGptOssSink = Problem::kHasGptOssSink; diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp index 88cf2e3f57f..11cd8a0cd85 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -58,8 +58,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kIsPagedKV = Problem::kIsPagedKV; static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits; - static constexpr FmhaSinkMode kSinkMode = Problem::kSinkMode; - static constexpr bool kHasSink = Problem::kHasSink; static constexpr bool kHasStreamSink = Problem::kHasStreamSink; static constexpr bool kHasGptOssSink = Problem::kHasGptOssSink; diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index cb4efbbef4e..e948849aec6 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -74,7 +74,6 @@ struct BlockFmhaPipelineProblem static constexpr auto QScaleEnum = Traits::QScaleEnum; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; static constexpr FmhaSinkMode kSinkMode = Traits::kSinkMode; - static constexpr bool kHasSink = Traits::kHasSink; static constexpr bool kHasStreamSink = Traits::kHasStreamSink; static constexpr bool kHasGptOssSink = Traits::kHasGptOssSink; }; @@ -190,7 +189,6 @@ struct BlockFmhaFwdPagedKVPipelineProblem static constexpr bool kIsPagedKV = Traits::kIsPagedKV; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; static constexpr FmhaSinkMode kSinkMode = Traits::kSinkMode; - static constexpr bool kHasSink = Traits::kHasSink; static constexpr bool kHasStreamSink = Traits::kHasStreamSink; static constexpr bool kHasGptOssSink = Traits::kHasGptOssSink; }; @@ -247,7 +245,6 @@ struct BlockFmhaFwdSplitKVPipelineProblem static constexpr bool kMergeNumHeadGroupsSeqLenQ = Traits::kMergeNumHeadGroupsSeqLenQ; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; static constexpr FmhaSinkMode kSinkMode = Traits::kSinkMode; - static constexpr bool kHasSink = Traits::kHasSink; static constexpr bool kHasStreamSink = Traits::kHasStreamSink; static constexpr bool kHasGptOssSink = Traits::kHasGptOssSink; }; diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 50cd5b54a64..4e2f08ec844 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -65,8 +65,6 @@ struct BlockFmhaPipelineQRKSVS static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kHasDropout = Problem::kHasDropout; static constexpr auto QScaleEnum = Problem::QScaleEnum; - static constexpr FmhaSinkMode kSinkMode = Problem::kSinkMode; - static constexpr bool kHasSink = Problem::kHasSink; static constexpr bool kHasStreamSink = Problem::kHasStreamSink; static constexpr bool kHasGptOssSink = Problem::kHasGptOssSink; diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 193333a093d..da44740457e 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -64,8 +64,6 @@ struct BlockFmhaPipelineQRKSVSAsync static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kHasDropout = Problem::kHasDropout; - static constexpr FmhaSinkMode kSinkMode = Problem::kSinkMode; - static constexpr bool kHasSink = Problem::kHasSink; static constexpr bool kHasStreamSink = Problem::kHasStreamSink; static constexpr bool kHasGptOssSink = Problem::kHasGptOssSink; diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp index ef505b07abb..aacec1c22cd 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp @@ -70,8 +70,6 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kHasUnevenSplits = true; - static constexpr FmhaSinkMode kSinkMode = Problem::kSinkMode; - static constexpr bool kHasSink = Problem::kHasSink; static constexpr bool kHasStreamSink = Problem::kHasStreamSink; static constexpr bool kHasGptOssSink = Problem::kHasGptOssSink; From 7b89aa3ecb4b294d90b905444db98e15bc965412 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Tue, 31 Mar 2026 01:01:36 -0500 Subject: [PATCH 03/25] fmha: rename has_sink to has_stream_sink in runtime traits has_sink was ambiguous alongside has_gptoss_sink. Rename it to has_stream_sink to clearly indicate it controls the StreamLLM sliding-window sink, matching the kHasStreamSink naming on the compile-time side. Updated in fmha_fwd_traits/pagedkv/splitkv structs, fmha_fwd_runner, and all three codegen dispatch templates. --- .../example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 4 ++-- .../example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py | 2 +- .../ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py | 2 +- .../composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp | 6 +++--- .../example/ck_tile/01_fmha/fmha_fwd_runner.hpp | 2 +- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index c194a07f756..d5f5736d4b8 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -50,7 +50,7 @@ "both": "ck_tile::FmhaSinkMode::kBoth", } -# For backward compat dispatch check: map sink mode to (has_sink, has_gptoss_sink) +# For dispatch check: map sink mode to (has_stream_sink, has_gptoss_sink) SINK_MODE_DISPATCH_MAP = { "none": ("false", "false"), "stream": ("true", "false"), @@ -264,7 +264,7 @@ }} """ -FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.skip_min_seqlen_q == {F_skip}) && (t.has_sink == {F_stream_sink}) && (t.has_gptoss_sink == {F_gptoss_sink}) && +FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.skip_min_seqlen_q == {F_skip}) && (t.has_stream_sink == {F_stream_sink}) && (t.has_gptoss_sink == {F_gptoss_sink}) && ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}, {F_sink_mode}>; return fmha_fwd_(s, a); diff --git a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 4de16a706a7..062db1b60ec 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -301,7 +301,7 @@ """ FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) && - ((a.block_table_ptr != nullptr) == {F_pagedkv}) && (t.has_sink == {F_stream_sink}) && (t.has_gptoss_sink == {F_gptoss_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ + ((a.block_table_ptr != nullptr) == {F_pagedkv}) && (t.has_stream_sink == {F_stream_sink}) && (t.has_gptoss_sink == {F_gptoss_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv},{F_sink_mode}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; // get combine kernel tile sizes diff --git a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py index 787c00ff778..8f3583e72ca 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py +++ b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py @@ -151,7 +151,7 @@ }} """ -FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && (t.has_sink == {F_stream_sink}) && (t.has_gptoss_sink == {F_gptoss_sink}) && +FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && (t.has_stream_sink == {F_stream_sink}) && (t.has_gptoss_sink == {F_gptoss_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip},{F_sink_mode}>; return fmha_fwd_pagedkv_(s, a); diff --git a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp index 0c6f0372b6a..3607b43d917 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -1685,7 +1685,7 @@ struct fmha_fwd_traits bool has_dropout; quant_scale_enum qscale_type; bool skip_min_seqlen_q = false; - bool has_sink = false; // StreamLLM sliding-window sink + bool has_stream_sink = false; // StreamLLM sliding-window sink bool has_gptoss_sink = false; // GPT-OSS learnable softmax bias sink // TODO: padding check is inside this api }; @@ -1705,7 +1705,7 @@ struct fmha_fwd_pagedkv_traits bool use_pagedkv = true; bool do_fp8_static_quant = false; bool skip_min_seqlen_q = false; - bool has_sink = false; // StreamLLM sliding-window sink + bool has_stream_sink = false; // StreamLLM sliding-window sink bool has_gptoss_sink = false; // GPT-OSS learnable softmax bias sink // TODO: padding check is inside this api }; @@ -1726,7 +1726,7 @@ struct fmha_fwd_splitkv_traits bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum bool has_lse; bool do_fp8_static_quant = false; - bool has_sink = false; // StreamLLM sliding-window sink + bool has_stream_sink = false; // StreamLLM sliding-window sink bool has_gptoss_sink = false; // GPT-OSS learnable softmax bias sink // TODO: padding check is inside this api }; diff --git a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index 3b3de84fd06..012d250a624 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -1133,7 +1133,7 @@ fwd_result fmha_fwd_run(mode_enum mode, traits.has_logits_soft_cap = 0.f < logits_soft_cap; traits.mask_type = mask.type; traits.bias_type = bias.type; - traits.has_sink = mask.sink > 0 ? true : false; // StreamLLM sink + traits.has_stream_sink = mask.sink > 0 ? true : false; // StreamLLM sink traits.has_gptoss_sink = init_sink_value != 0; // GPT-OSS sink traits.has_lse = lse; From d6b11a91a02d9e08d4373fb3f26b5a7f5a141adb Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Tue, 31 Mar 2026 01:20:18 -0500 Subject: [PATCH 04/25] fmha: deduplicate SINK_*_MAP by importing from fmha_fwd SINK_MODE_MAP, SINK_MODE_DISPATCH_MAP, and SINK_NAME_MAP were defined identically in fmha_fwd.py, fmha_fwd_splitkv.py, and fmha_pagedkv_prefill.py. Remove the duplicate definitions from the latter two and import them from fmha_fwd instead. --- .../01_fmha/codegen/ops/fmha_fwd_splitkv.py | 24 +++---------------- .../codegen/ops/fmha_pagedkv_prefill.py | 24 +++---------------- 2 files changed, 6 insertions(+), 42 deletions(-) diff --git a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 062db1b60ec..7569725b942 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -32,6 +32,9 @@ FMHA_FWD_API_PER_ARCH, FMHA_FWD_API_PER_DTYPE, FMHA_FWD_API_PER_HDIM_CASE, + SINK_MODE_MAP, + SINK_MODE_DISPATCH_MAP, + SINK_NAME_MAP, ) @@ -40,27 +43,6 @@ "qr_nwarp_sshuffle": "ck_tile::BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS", } -SINK_MODE_MAP = { - "none": "ck_tile::FmhaSinkMode::kNone", - "stream": "ck_tile::FmhaSinkMode::kStreamLLM", - "gptoss": "ck_tile::FmhaSinkMode::kGptOss", - "both": "ck_tile::FmhaSinkMode::kBoth", -} - -SINK_MODE_DISPATCH_MAP = { - "none": ("false", "false"), - "stream": ("true", "false"), - "gptoss": ("false", "true"), - "both": ("true", "true"), -} - -SINK_NAME_MAP = { - "none": "_nsink", - "stream": "_ssink", - "gptoss": "_gsink", - "both": "_bsink", -} - FMHA_FWD_SPLITKV_KERNEL_BODY = """ #include diff --git a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py index 8f3583e72ca..1bd5c0787d9 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py +++ b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py @@ -31,30 +31,12 @@ FMHA_FWD_API_PER_ARCH, FMHA_FWD_API_PER_DTYPE, FMHA_FWD_API_PER_HDIM_CASE, + SINK_MODE_MAP, + SINK_MODE_DISPATCH_MAP, + SINK_NAME_MAP, ) -SINK_MODE_MAP = { - "none": "ck_tile::FmhaSinkMode::kNone", - "stream": "ck_tile::FmhaSinkMode::kStreamLLM", - "gptoss": "ck_tile::FmhaSinkMode::kGptOss", - "both": "ck_tile::FmhaSinkMode::kBoth", -} - -SINK_MODE_DISPATCH_MAP = { - "none": ("false", "false"), - "stream": ("true", "false"), - "gptoss": ("false", "true"), - "both": ("true", "true"), -} - -SINK_NAME_MAP = { - "none": "_nsink", - "stream": "_ssink", - "gptoss": "_gsink", - "both": "_bsink", -} - FMHA_FWD_PAGEDKV_PIPELINE_MAP = { "qr_pagedkv": "ck_tile::BlockFmhaFwdPagedKVPipelineQRKSVS" } From bbef1af9c9385f54688e659975e33623f644a4e4 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Tue, 31 Mar 2026 01:28:03 -0500 Subject: [PATCH 05/25] fmha: remove dead kHasSink constant kHasSink was superseded by kHasStreamSink and kHasGptOssSink when FmhaSinkMode was introduced. No code reads kHasSink after that change; all call sites were already using the more specific constants. Remove kHasSink from TileFmhaTraits / TileFmhaFwdSplitKVTraits / TileFmhaFwdPagedKVTraits (tile_fmha_traits.hpp), the batch-prefill pipeline forwarding declaration, and the three fmha_fwd_*traits_ structs in fmha_fwd.hpp. --- .../composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp | 3 --- .../block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp | 1 - .../include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp | 4 ---- 3 files changed, 8 deletions(-) diff --git a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp index 3607b43d917..d7ff69baebb 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -1421,7 +1421,6 @@ struct fmha_fwd_traits_ static constexpr bool kUseTrLoad = kUseTrLoad_; static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; static constexpr ck_tile::FmhaSinkMode kSinkMode = kSinkMode_; - static constexpr bool kHasSink = (kSinkMode != ck_tile::FmhaSinkMode::kNone); static constexpr bool kHasStreamSink = (kSinkMode == ck_tile::FmhaSinkMode::kStreamLLM || kSinkMode == ck_tile::FmhaSinkMode::kBoth); static constexpr bool kHasGptOssSink = @@ -1538,7 +1537,6 @@ struct fmha_fwd_pagedkv_traits_ static constexpr bool kPadDv = kPadDv_; static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; static constexpr ck_tile::FmhaSinkMode kSinkMode = kSinkMode_; - static constexpr bool kHasSink = (kSinkMode != ck_tile::FmhaSinkMode::kNone); static constexpr bool kHasStreamSink = (kSinkMode == ck_tile::FmhaSinkMode::kStreamLLM || kSinkMode == ck_tile::FmhaSinkMode::kBoth); static constexpr bool kHasGptOssSink = @@ -1594,7 +1592,6 @@ struct fmha_fwd_splitkv_traits_ static constexpr bool kPadDv = kPadDv_; static constexpr bool kIsPagedKV = kIsPagedKV_; static constexpr ck_tile::FmhaSinkMode kSinkMode = kSinkMode_; - static constexpr bool kHasSink = (kSinkMode != ck_tile::FmhaSinkMode::kNone); static constexpr bool kHasStreamSink = (kSinkMode == ck_tile::FmhaSinkMode::kStreamLLM || kSinkMode == ck_tile::FmhaSinkMode::kBoth); static constexpr bool kHasGptOssSink = diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index b4a2db32743..e1c452d86ef 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -293,7 +293,6 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync static constexpr auto kKVMemoryLayout = Problem::kKVMemoryLayout; static constexpr auto QScaleEnum = Problem::QScaleEnum; static constexpr FmhaSinkMode kSinkMode = Problem::kSinkMode; - static constexpr bool kHasSink = Problem::kHasSink; static constexpr bool kHasStreamSink = Problem::kHasStreamSink; static constexpr bool kHasGptOssSink = Problem::kHasGptOssSink; diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index 456b11a556e..42c5a172c0a 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -52,8 +52,6 @@ struct TileFmhaTraits static constexpr index_t kBlockPerCu = kBlockPerCu_; static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; static constexpr FmhaSinkMode kSinkMode = kSinkMode_; - // Derived convenience constants - static constexpr bool kHasSink = (kSinkMode != FmhaSinkMode::kNone); static constexpr bool kHasStreamSink = (kSinkMode == FmhaSinkMode::kStreamLLM || kSinkMode == FmhaSinkMode::kBoth); static constexpr bool kHasGptOssSink = @@ -147,7 +145,6 @@ struct TileFmhaFwdPagedKVTraits static constexpr index_t kBlockPerCu = kBlockPerCu_; static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; static constexpr FmhaSinkMode kSinkMode = kSinkMode_; - static constexpr bool kHasSink = (kSinkMode != FmhaSinkMode::kNone); static constexpr bool kHasStreamSink = (kSinkMode == FmhaSinkMode::kStreamLLM || kSinkMode == FmhaSinkMode::kBoth); static constexpr bool kHasGptOssSink = @@ -185,7 +182,6 @@ struct TileFmhaFwdSplitKVTraits static constexpr bool kMergeNumHeadGroupsSeqLenQ = kMergeNumHeadGroupsSeqLenQ_; static constexpr index_t kBlockPerCu = kBlockPerCu_; static constexpr FmhaSinkMode kSinkMode = kSinkMode_; - static constexpr bool kHasSink = (kSinkMode != FmhaSinkMode::kNone); static constexpr bool kHasStreamSink = (kSinkMode == FmhaSinkMode::kStreamLLM || kSinkMode == FmhaSinkMode::kBoth); static constexpr bool kHasGptOssSink = From 0b4433bb4cd5955b4657fe993de05d8ab16c3829 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Tue, 31 Mar 2026 01:43:45 -0500 Subject: [PATCH 06/25] fmha: remove redundant tile_fmha_traits include from fmha_fwd_kernel The include was added when kSinkMode (type FmhaSinkMode) was forwarded in FmhaFwdKernel, but kSinkMode was removed in the cleanup commit. The only remaining sink usage is kHasGptOssSink (a plain bool), which does not require FmhaSinkMode to be visible. tile_fmha_traits.hpp is already reachable transitively through the pipeline headers. --- .../include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 0d100f37850..fc846690e28 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -8,7 +8,6 @@ #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp" #include "ck_tile/ops/fmha/block/variants.hpp" -#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include #include From 3c7970b2dd0bf7a3b8a14bf88f8ec54c59aa7d21 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Tue, 31 Mar 2026 01:50:00 -0500 Subject: [PATCH 07/25] fmha: guard pagedkv sink_value with if constexpr(kHasGptOssSink) fmha_fwd_kernel.hpp was updated to use if constexpr(kHasGptOssSink) for sink_value computation, but fmha_fwd_pagedkv_kernel.hpp was missed and still used the runtime sink_ptr != nullptr check. Apply the same fix: add kHasGptOssSink forwarding constant and replace the runtime check with if constexpr, eliminating dead computation in no-sink kernels. --- .../ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp index 0234c342d01..e7284d277c2 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp @@ -56,7 +56,8 @@ struct FmhaFwdPagedKVKernel static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ; static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV; - static constexpr FmhaSinkMode kSinkMode = FmhaPipeline::kSinkMode; + static constexpr FmhaSinkMode kSinkMode = FmhaPipeline::kSinkMode; + static constexpr bool kHasGptOssSink = FmhaPipeline::kHasGptOssSink; using AttentionVariant = ck_tile::remove_cvref_t; using FmhaMask = ck_tile::remove_cvref_t; @@ -919,10 +920,12 @@ struct FmhaFwdPagedKVKernel long_index_t batch_offset_lse = 0; long_index_t batch_offset_o = 0; index_t kv_l2p_offset = 0; - const float sink_value = - kargs.sink_ptr != nullptr - ? (*(static_cast(kargs.sink_ptr) + i_nhead)) / kargs.scale_s - : -numeric::infinity(); + const float sink_value = [&]() { + if constexpr(kHasGptOssSink) + return (*(static_cast(kargs.sink_ptr) + i_nhead)) / kargs.scale_s; + else + return -numeric::infinity(); + }(); if constexpr(kIsGroupMode) { From 395754e1ba71258bb44dda6a4a5c39aaa1936170 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Tue, 31 Mar 2026 01:51:04 -0500 Subject: [PATCH 08/25] fmha: guard splitkv sink_value with if constexpr(kHasGptOssSink) Same fix as pagedkv_kernel: replace the runtime sink_ptr != nullptr check with if constexpr(kHasGptOssSink), eliminating dead computation in no-sink kernels. Add kHasGptOssSink forwarding constant alongside the existing kSinkMode. --- .../ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index 4fbf30e11b7..4583ad9d08d 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -52,7 +52,8 @@ struct FmhaFwdSplitKVKernel static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV; - static constexpr FmhaSinkMode kSinkMode = FmhaPipeline::Problem::kSinkMode; + static constexpr FmhaSinkMode kSinkMode = FmhaPipeline::Problem::kSinkMode; + static constexpr bool kHasGptOssSink = FmhaPipeline::Problem::kHasGptOssSink; static constexpr bool kMergeNumHeadGroupsSeqLenQ = FmhaPipeline::Problem::kMergeNumHeadGroupsSeqLenQ; using AttentionVariant = ck_tile::remove_cvref_t; @@ -621,10 +622,12 @@ struct FmhaFwdSplitKVKernel long_index_t batch_offset_o_acc = 0; index_t kv_l2p_offset = 0; // logical-to-physical offset of seqlen_k coordinate. only used for paged-kvcache - const float sink_value = - kargs.sink_ptr != nullptr - ? (*(static_cast(kargs.sink_ptr) + i_nhead)) / kargs.scale_s - : -numeric::infinity(); + const float sink_value = [&]() { + if constexpr(kHasGptOssSink) + return (*(static_cast(kargs.sink_ptr) + i_nhead)) / kargs.scale_s; + else + return -numeric::infinity(); + }(); if constexpr(kIsGroupMode) { From eb966f091abe46ddc7dd681b333c621b94623bb3 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Tue, 31 Mar 2026 02:37:17 -0500 Subject: [PATCH 09/25] fmha: unify kSinkMode access path and reduce tile_fmha_traits includes - Both pagedkv and splitkv kernels now read kSinkMode from FmhaPipeline::Problem::kSinkMode (previously pagedkv read from FmhaPipeline::kSinkMode directly). Also unify kHasGptOssSink to read from Problem layer in pagedkv kernel. - Use 'auto' for kSinkMode in the three PipelineProblem structs and the two pipeline structs (pagedkv, batch_prefill), eliminating the need to spell out FmhaSinkMode as a type name in those files. - Remove the now-redundant tile_fmha_traits.hpp include from block_fmha_pipeline_problem.hpp, block_fmha_fwd_pagedkv_pipeline_ qr_ks_vs.hpp, and block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp. --- .../ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp | 4 ++-- .../block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp | 3 +-- .../pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp | 3 +-- .../ops/fmha/pipeline/block_fmha_pipeline_problem.hpp | 7 +++---- 4 files changed, 7 insertions(+), 10 deletions(-) diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp index e7284d277c2..b216b0ea790 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp @@ -56,8 +56,8 @@ struct FmhaFwdPagedKVKernel static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ; static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV; - static constexpr FmhaSinkMode kSinkMode = FmhaPipeline::kSinkMode; - static constexpr bool kHasGptOssSink = FmhaPipeline::kHasGptOssSink; + static constexpr FmhaSinkMode kSinkMode = FmhaPipeline::Problem::kSinkMode; + static constexpr bool kHasGptOssSink = FmhaPipeline::Problem::kHasGptOssSink; using AttentionVariant = ck_tile::remove_cvref_t; using FmhaMask = ck_tile::remove_cvref_t; diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index e1c452d86ef..a4d870195b1 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -4,7 +4,6 @@ #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp" @@ -292,7 +291,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync static constexpr bool kHasDropout = Problem::kHasDropout; static constexpr auto kKVMemoryLayout = Problem::kKVMemoryLayout; static constexpr auto QScaleEnum = Problem::QScaleEnum; - static constexpr FmhaSinkMode kSinkMode = Problem::kSinkMode; + static constexpr auto kSinkMode = Problem::kSinkMode; static constexpr bool kHasStreamSink = Problem::kHasStreamSink; static constexpr bool kHasGptOssSink = Problem::kHasGptOssSink; diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp index 6fc4925a8d1..08297305237 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp @@ -6,7 +6,6 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs_default_policy.hpp" -#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include "ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" @@ -59,7 +58,7 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kIsPagedKV = Problem::kIsPagedKV; - static constexpr FmhaSinkMode kSinkMode = Problem::kSinkMode; + static constexpr auto kSinkMode = Problem::kSinkMode; static constexpr bool kHasStreamSink = Problem::kHasStreamSink; static constexpr bool kHasGptOssSink = Problem::kHasGptOssSink; diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index e948849aec6..52b329cc257 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -6,7 +6,6 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp" #include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp" -#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" namespace ck_tile { @@ -73,7 +72,7 @@ struct BlockFmhaPipelineProblem static constexpr bool kHasDropout = Traits::kHasDropout; static constexpr auto QScaleEnum = Traits::QScaleEnum; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; - static constexpr FmhaSinkMode kSinkMode = Traits::kSinkMode; + static constexpr auto kSinkMode = Traits::kSinkMode; static constexpr bool kHasStreamSink = Traits::kHasStreamSink; static constexpr bool kHasGptOssSink = Traits::kHasGptOssSink; }; @@ -188,7 +187,7 @@ struct BlockFmhaFwdPagedKVPipelineProblem static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant; static constexpr bool kIsPagedKV = Traits::kIsPagedKV; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; - static constexpr FmhaSinkMode kSinkMode = Traits::kSinkMode; + static constexpr auto kSinkMode = Traits::kSinkMode; static constexpr bool kHasStreamSink = Traits::kHasStreamSink; static constexpr bool kHasGptOssSink = Traits::kHasGptOssSink; }; @@ -244,7 +243,7 @@ struct BlockFmhaFwdSplitKVPipelineProblem static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits; static constexpr bool kMergeNumHeadGroupsSeqLenQ = Traits::kMergeNumHeadGroupsSeqLenQ; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; - static constexpr FmhaSinkMode kSinkMode = Traits::kSinkMode; + static constexpr auto kSinkMode = Traits::kSinkMode; static constexpr bool kHasStreamSink = Traits::kHasStreamSink; static constexpr bool kHasGptOssSink = Traits::kHasGptOssSink; }; From eb25153d2f8d957f0689eeaae6f5c7ef945411af Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Tue, 31 Mar 2026 02:53:15 -0500 Subject: [PATCH 10/25] fmha: replace kSinkMode enum comparison with bool flags in kernel GetName pagedkv and splitkv kernels used FmhaSinkMode enum values directly in GetName() to generate the _nsink/_ssink/_gsink/_bsink suffix, requiring tile_fmha_traits.hpp to be included. Replace with kHasStreamSink and kHasGptOssSink bool comparisons (already available via Problem layer), eliminating the need for FmhaSinkMode type visibility in these files. Remove tile_fmha_traits.hpp include from both kernel headers. Replace kSinkMode forwarding with kHasStreamSink forwarding. --- .../ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp | 5 ++--- .../ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp | 5 ++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp index b216b0ea790..d63d9a36283 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp @@ -7,7 +7,6 @@ #include "ck_tile/ops/common.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/variants.hpp" -#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include #include @@ -56,7 +55,7 @@ struct FmhaFwdPagedKVKernel static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ; static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV; - static constexpr FmhaSinkMode kSinkMode = FmhaPipeline::Problem::kSinkMode; + static constexpr bool kHasStreamSink = FmhaPipeline::Problem::kHasStreamSink; static constexpr bool kHasGptOssSink = FmhaPipeline::Problem::kHasGptOssSink; using AttentionVariant = ck_tile::remove_cvref_t; @@ -105,7 +104,7 @@ struct FmhaFwdPagedKVKernel "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) + (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ) + - (kSinkMode == FmhaSinkMode::kNone ? "_nsink" : kSinkMode == FmhaSinkMode::kStreamLLM ? "_ssink" : kSinkMode == FmhaSinkMode::kGptOss ? "_gsink" : "_bsink"); + (!kHasStreamSink && !kHasGptOssSink ? "_nsink" : kHasStreamSink && !kHasGptOssSink ? "_ssink" : !kHasStreamSink && kHasGptOssSink ? "_gsink" : "_bsink"); #undef _SS_ #undef _TS_ // clang-format on diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index 4583ad9d08d..d7869ada0bb 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -7,7 +7,6 @@ #include "ck_tile/ops/common.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/variants.hpp" -#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include #include @@ -52,7 +51,7 @@ struct FmhaFwdSplitKVKernel static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV; - static constexpr FmhaSinkMode kSinkMode = FmhaPipeline::Problem::kSinkMode; + static constexpr bool kHasStreamSink = FmhaPipeline::Problem::kHasStreamSink; static constexpr bool kHasGptOssSink = FmhaPipeline::Problem::kHasGptOssSink; static constexpr bool kMergeNumHeadGroupsSeqLenQ = FmhaPipeline::Problem::kMergeNumHeadGroupsSeqLenQ; @@ -105,7 +104,7 @@ struct FmhaFwdSplitKVKernel (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant") + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ) + - (kSinkMode == FmhaSinkMode::kNone ? "_nsink" : kSinkMode == FmhaSinkMode::kStreamLLM ? "_ssink" : kSinkMode == FmhaSinkMode::kGptOss ? "_gsink" : "_bsink"); + (!kHasStreamSink && !kHasGptOssSink ? "_nsink" : kHasStreamSink && !kHasGptOssSink ? "_ssink" : !kHasStreamSink && kHasGptOssSink ? "_gsink" : "_bsink"); #undef _SS_ #undef _TS_ // clang-format on From ca10670751b2673d88c7b7b7ec88c2f4a32130ee Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Tue, 31 Mar 2026 02:54:58 -0500 Subject: [PATCH 11/25] fmha: remove redundant tile_fmha_traits include from pipeline files All five pipeline files include tile_fmha_traits.hpp but only use kHasStreamSink and kHasGptOssSink (plain bools forwarded from Problem), never FmhaSinkMode directly. The include was unnecessary; remove it. --- .../block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp | 1 - .../fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp | 1 - .../ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 1 - .../ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp | 1 - .../fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp | 1 - 5 files changed, 5 deletions(-) diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp index 4541cc7d488..0360867edf2 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp @@ -4,7 +4,6 @@ #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp index 11cd8a0cd85..69c16d117bc 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -4,7 +4,6 @@ #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/gemm/warp/warp_wmma_gemm_gfx11_utils.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 4e2f08ec844..598b9efa80c 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -4,7 +4,6 @@ #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/block_dropout.hpp" #include "ck_tile/ops/fmha/block/cast_tile_mx.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index da44740457e..41d460f83b6 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -4,7 +4,6 @@ #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp" diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp index aacec1c22cd..04fbda1bdc0 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp @@ -4,7 +4,6 @@ #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" From 66b747492dd8eadf2e7aefa651c9bc68e60ab188 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Tue, 31 Mar 2026 03:02:03 -0500 Subject: [PATCH 12/25] fmha: revert stray changes in block_fmha_pipeline_qr_ks_vs Restore the original code order and comments that were inadvertently modified in the FmhaSinkMode commit: - Revert O normalization moved before LSE store back to after it - Restore '// store lse' and '// finally, O' comments These changes belong to a different PR. --- .../pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 48 +++++++++---------- 1 file changed, 23 insertions(+), 25 deletions(-) diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 598b9efa80c..f4b151c0e8d 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -976,31 +976,7 @@ struct BlockFmhaPipelineQRKSVS } } while(++i_total_loops < num_total_loop); - // finally, O -- normalize BEFORE storing LSE to avoid VGPR aliasing: - // the lse tile creation (below) may reuse o_acc VGPRs; normalizing first - // ensures o_acc is fully consumed before those registers can be reused. - constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); - - sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); - const auto tmp = [&]() { - // When bias carries -inf masks the denominator can be zero; guard the normalization - // so we do not divide by zero after a fully masked row. - if constexpr(FmhaMask::IsMasking || - BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) - { - return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; - } - else - return 1 / l[i_idx]; - }(); - sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - o_acc(i_j_idx) *= tmp; - }); - }); - - // store lse -- AFTER O normalization to prevent VGPR reuse corruption + // store lse if constexpr(kStoreLSE) { auto lse = make_static_distributed_tensor(m.get_tile_distribution()); @@ -1042,6 +1018,28 @@ struct BlockFmhaPipelineQRKSVS store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); } + // finally, O + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = [&]() { + // When bias carries -inf masks the denominator can be zero; guard the normalization + // so we do not divide by zero after a fully masked row. + if constexpr(FmhaMask::IsMasking || + BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; + } + else + return 1 / l[i_idx]; + }(); + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + }); + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); return o_acc; From 6375bf4cb188dafe14384b2948890dfc6e1c546d Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Tue, 31 Mar 2026 03:10:33 -0500 Subject: [PATCH 13/25] fmha: apply clang-format alignment fixes --- .../ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp | 2 +- .../block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp | 2 +- .../pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp | 2 +- .../ops/fmha/pipeline/block_fmha_pipeline_problem.hpp | 6 +++--- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp index d63d9a36283..5ce7d2ab801 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp @@ -919,7 +919,7 @@ struct FmhaFwdPagedKVKernel long_index_t batch_offset_lse = 0; long_index_t batch_offset_o = 0; index_t kv_l2p_offset = 0; - const float sink_value = [&]() { + const float sink_value = [&]() { if constexpr(kHasGptOssSink) return (*(static_cast(kargs.sink_ptr) + i_nhead)) / kargs.scale_s; else diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index a4d870195b1..af0bc90915f 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -291,7 +291,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync static constexpr bool kHasDropout = Problem::kHasDropout; static constexpr auto kKVMemoryLayout = Problem::kKVMemoryLayout; static constexpr auto QScaleEnum = Problem::QScaleEnum; - static constexpr auto kSinkMode = Problem::kSinkMode; + static constexpr auto kSinkMode = Problem::kSinkMode; static constexpr bool kHasStreamSink = Problem::kHasStreamSink; static constexpr bool kHasGptOssSink = Problem::kHasGptOssSink; diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp index 08297305237..9f281820b9e 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp @@ -58,7 +58,7 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kIsPagedKV = Problem::kIsPagedKV; - static constexpr auto kSinkMode = Problem::kSinkMode; + static constexpr auto kSinkMode = Problem::kSinkMode; static constexpr bool kHasStreamSink = Problem::kHasStreamSink; static constexpr bool kHasGptOssSink = Problem::kHasGptOssSink; diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index 52b329cc257..e871b2c1289 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -72,7 +72,7 @@ struct BlockFmhaPipelineProblem static constexpr bool kHasDropout = Traits::kHasDropout; static constexpr auto QScaleEnum = Traits::QScaleEnum; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; - static constexpr auto kSinkMode = Traits::kSinkMode; + static constexpr auto kSinkMode = Traits::kSinkMode; static constexpr bool kHasStreamSink = Traits::kHasStreamSink; static constexpr bool kHasGptOssSink = Traits::kHasGptOssSink; }; @@ -187,7 +187,7 @@ struct BlockFmhaFwdPagedKVPipelineProblem static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant; static constexpr bool kIsPagedKV = Traits::kIsPagedKV; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; - static constexpr auto kSinkMode = Traits::kSinkMode; + static constexpr auto kSinkMode = Traits::kSinkMode; static constexpr bool kHasStreamSink = Traits::kHasStreamSink; static constexpr bool kHasGptOssSink = Traits::kHasGptOssSink; }; @@ -243,7 +243,7 @@ struct BlockFmhaFwdSplitKVPipelineProblem static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits; static constexpr bool kMergeNumHeadGroupsSeqLenQ = Traits::kMergeNumHeadGroupsSeqLenQ; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; - static constexpr auto kSinkMode = Traits::kSinkMode; + static constexpr auto kSinkMode = Traits::kSinkMode; static constexpr bool kHasStreamSink = Traits::kHasStreamSink; static constexpr bool kHasGptOssSink = Traits::kHasGptOssSink; }; From 22be529a6ca69080baabc8db14b080a4b0e515e1 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Tue, 31 Mar 2026 03:55:50 -0500 Subject: [PATCH 14/25] fmha: validate --sink argument in generate.py --- .../composablekernel/example/ck_tile/01_fmha/generate.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/projects/composablekernel/example/ck_tile/01_fmha/generate.py b/projects/composablekernel/example/ck_tile/01_fmha/generate.py index 5e86cf83f18..6e6ed00944a 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/generate.py +++ b/projects/composablekernel/example/ck_tile/01_fmha/generate.py @@ -189,7 +189,14 @@ def list_blobs( filter_list = args.filter.split(",") filter_list.extend([""] * (len(api_list) - len(filter_list))) optdim_list = [int(hdim) for hdim in args.optdim.split(",")] - sink_modes = tuple(args.sink.split(",")) + sink_modes = tuple(s.strip() for s in args.sink.split(",")) + valid_sink_modes = {"none", "stream", "gptoss", "both"} + invalid_sink_modes = set(sink_modes) - valid_sink_modes + if invalid_sink_modes: + parser.error( + f"Invalid sink mode(s): {sorted(invalid_sink_modes)}. " + f"Valid values are: {sorted(valid_sink_modes)}" + ) if args.list_blobs is not None: list_blobs( From 1d1f7a005d3314e5e73aef2707781d5911501714 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Mon, 6 Apr 2026 19:17:05 -0500 Subject: [PATCH 15/25] [CK Tile] Gate sink smoke tests behind opt-in flags in smoke_test_fwd.sh run_sink_mask_tests (StreamLLM) and run_sink_init_tests (GPT-OSS) require kernel instances compiled with FMHA_FWD_SINK_MODES=stream/gptoss respectively. Running them unconditionally against a default build (FMHA_FWD_SINK_MODES=none) causes "not supported yet" failures for all cases. Gate each behind a new CLI flag (-m / -g), consistent with the existing -s (splitkv) and -a (appendkv) opt-in pattern. Usage comments document the required build configuration alongside each flag. --- .../ck_tile/01_fmha/script/smoke_test_fwd.sh | 47 ++++++++++++++++--- 1 file changed, 41 insertions(+), 6 deletions(-) diff --git a/projects/composablekernel/example/ck_tile/01_fmha/script/smoke_test_fwd.sh b/projects/composablekernel/example/ck_tile/01_fmha/script/smoke_test_fwd.sh index 1e9942a6e1b..6578055d952 100755 --- a/projects/composablekernel/example/ck_tile/01_fmha/script/smoke_test_fwd.sh +++ b/projects/composablekernel/example/ck_tile/01_fmha/script/smoke_test_fwd.sh @@ -28,10 +28,33 @@ COMMON_ARGS='-v=1 -warmup=0 -repeat=1' TEST_SPLITKV=0 TEST_APPENDKV=0 -# options: -# -s: run splitkv tests -# -a: run appendkv tests -while getopts ":sa" opt; do +TEST_STREAM_SINK=0 +TEST_GPTOSS_SINK=0 +# Usage: +# bash smoke_test_fwd.sh [options] +# +# Options: +# -s run splitkv tests +# -a run appendkv tests +# -m run StreamLLM sliding-window sink mask tests +# requires the binary to be built with FMHA_FWD_SINK_MODES containing "stream", e.g.: +# cmake -DFMHA_FWD_SINK_MODES="none,stream" ... +# cmake --build . --target tile_example_fmha_fwd +# bash smoke_test_fwd.sh -m +# -g run GPT-OSS sink (init_sink) tests +# requires the binary to be built with FMHA_FWD_SINK_MODES containing "gptoss", e.g.: +# cmake -DFMHA_FWD_SINK_MODES="none,gptoss" ... +# cmake --build . --target tile_example_fmha_fwd +# bash smoke_test_fwd.sh -g +# +# Examples: +# bash smoke_test_fwd.sh # default: base tests only +# bash smoke_test_fwd.sh -s # also run splitkv tests +# bash smoke_test_fwd.sh -a # also run appendkv tests +# bash smoke_test_fwd.sh -m # also run StreamLLM sink tests +# bash smoke_test_fwd.sh -g # also run GPT-OSS sink tests +# bash smoke_test_fwd.sh -s -a -m -g # run all tests +while getopts ":samg" opt; do case "${opt}" in s) TEST_SPLITKV=1 @@ -39,6 +62,12 @@ while getopts ":sa" opt; do a) TEST_APPENDKV=1 ;; + m) + TEST_STREAM_SINK=1 + ;; + g) + TEST_GPTOSS_SINK=1 + ;; *) ;; esac @@ -300,8 +329,14 @@ run_padding_smoke_tests run_padding_basic_boundary_tests run_fp8bf16_tests run_fp8fp32_tests -run_sink_mask_tests -run_sink_init_tests + +if [ $TEST_STREAM_SINK -eq 1 ] ; then + run_sink_mask_tests +fi + +if [ $TEST_GPTOSS_SINK -eq 1 ] ; then + run_sink_init_tests +fi if [ $TEST_APPENDKV -eq 1 ] ; then run_fp16_appendkv_tests From 368851e9c78e5e0db08bc4d2312161924db304a5 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Sun, 12 Apr 2026 21:35:45 -0500 Subject: [PATCH 16/25] [CK Tile][FMHA] Address poyenc code review comments on PR #6057 Issue 1: Remove dead else branch in block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp. The inner if constexpr(kHasGptOssSink) at the early-exit LSE path was already inside an outer block guarded by the same condition, making the else unreachable. Issue 2: Replace nested ternary sink name suffix with FmhaSinkNameSuffix<> helper. Add constexpr helper to tile_fmha_traits.hpp that mirrors the Python SINK_NAME_MAP, and use it in fmha_fwd_splitkv_kernel.hpp and fmha_fwd_pagedkv_kernel.hpp. Issue 3: Document intentionally-ignored sink_modes in fmha_batch_prefill.py and fmha_fwd_appendkv.py. These APIs do not yet support sink kernel variants; the parameter exists only for handler signature uniformity. Issue 4: Replace duplicated `if api == "bwd"` branches in generate.py with a single _APIS_WITHOUT_SINK constant, with a comment explaining the exclusion. Issue 5: Print informational messages in smoke_test_fwd.sh when StreamLLM or GPT-OSS sink tests are skipped, so CI jobs without -m/-g flags get visible feedback. Nit 1: Add FmhaSinkModeHelper template to tile_fmha_traits.hpp. Replace the six copies of the kHasStreamSink/kHasGptOssSink two-line derivation in tile_fmha_traits.hpp and fmha_fwd.hpp with references to this single helper. Nit 2: Add comment on the float != 0 comparison for has_gptoss_sink in fmha_fwd_runner.hpp clarifying that callers must pass exactly 0.0 to disable. --- .../01_fmha/codegen/ops/fmha_batch_prefill.py | 4 ++ .../01_fmha/codegen/ops/fmha_fwd_appendkv.py | 4 ++ .../example/ck_tile/01_fmha/fmha_fwd.hpp | 18 +++---- .../ck_tile/01_fmha/fmha_fwd_runner.hpp | 4 +- .../example/ck_tile/01_fmha/generate.py | 8 ++- .../ck_tile/01_fmha/script/smoke_test_fwd.sh | 4 ++ .../fmha/kernel/fmha_fwd_pagedkv_kernel.hpp | 3 +- .../fmha/kernel/fmha_fwd_splitkv_kernel.hpp | 3 +- ..._batch_prefill_pipeline_qr_ks_vs_async.hpp | 12 ++--- .../ops/fmha/pipeline/tile_fmha_traits.hpp | 52 +++++++++++++------ 10 files changed, 74 insertions(+), 38 deletions(-) diff --git a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index 1eb22bb9871..7863ab3fbae 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -830,6 +830,9 @@ def write_blobs( mask_impl, sink_modes=("none",), ) -> None: + # sink_modes is intentionally not forwarded: batch_prefill does not yet support + # StreamLLM/GPT-OSS sink kernels. The parameter exists only for API uniformity + # with other fwd handlers. Non-"none" modes are silently treated as "none". api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) for kernel in kernels: write_single_fwd_kernel(kernel, output_dir) @@ -845,6 +848,7 @@ def list_blobs( mask_impl, sink_modes=("none",), ) -> None: + # sink_modes is intentionally not forwarded: see write_blobs for rationale. with file_path.open("a") as f: _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) for kernel in kernels: diff --git a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py index 7f1851380cb..505640514ce 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +++ b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -495,6 +495,9 @@ def write_blobs( mask_impl, sink_modes=("none",), ) -> None: + # sink_modes is intentionally not forwarded: appendkv does not support sink kernels. + # The parameter exists only for API uniformity with other fwd handlers. + # Non-"none" modes are silently treated as "none". api_pool, kernels = get_fwd_appendkv_blobs( targets, kernel_filter, receipt, mask_impl, optdim_list ) @@ -512,6 +515,7 @@ def list_blobs( mask_impl, sink_modes=("none",), ) -> None: + # sink_modes is intentionally not forwarded: see write_blobs for rationale. with file_path.open("a") as f: _, kernels = get_fwd_appendkv_blobs( targets, kernel_filter, receipt, mask_impl, optdim_list diff --git a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp index e3f0b98d9b4..f57dab750f9 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -1427,10 +1427,10 @@ struct fmha_fwd_traits_ static constexpr bool kUseTrLoad = kUseTrLoad_; static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; static constexpr ck_tile::FmhaSinkMode kSinkMode = kSinkMode_; - static constexpr bool kHasStreamSink = (kSinkMode == ck_tile::FmhaSinkMode::kStreamLLM || - kSinkMode == ck_tile::FmhaSinkMode::kBoth); + static constexpr bool kHasStreamSink = + ck_tile::FmhaSinkModeHelper::kHasStreamSink; static constexpr bool kHasGptOssSink = - (kSinkMode == ck_tile::FmhaSinkMode::kGptOss || kSinkMode == ck_tile::FmhaSinkMode::kBoth); + ck_tile::FmhaSinkModeHelper::kHasGptOssSink; }; template ::kHasStreamSink; static constexpr bool kHasGptOssSink = - (kSinkMode == ck_tile::FmhaSinkMode::kGptOss || kSinkMode == ck_tile::FmhaSinkMode::kBoth); + ck_tile::FmhaSinkModeHelper::kHasGptOssSink; }; template @@ -1598,10 +1598,10 @@ struct fmha_fwd_splitkv_traits_ static constexpr bool kPadDv = kPadDv_; static constexpr bool kIsPagedKV = kIsPagedKV_; static constexpr ck_tile::FmhaSinkMode kSinkMode = kSinkMode_; - static constexpr bool kHasStreamSink = (kSinkMode == ck_tile::FmhaSinkMode::kStreamLLM || - kSinkMode == ck_tile::FmhaSinkMode::kBoth); + static constexpr bool kHasStreamSink = + ck_tile::FmhaSinkModeHelper::kHasStreamSink; static constexpr bool kHasGptOssSink = - (kSinkMode == ck_tile::FmhaSinkMode::kGptOss || kSinkMode == ck_tile::FmhaSinkMode::kBoth); + ck_tile::FmhaSinkModeHelper::kHasGptOssSink; }; template diff --git a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index 012d250a624..820a82836c5 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -1134,7 +1134,9 @@ fwd_result fmha_fwd_run(mode_enum mode, traits.mask_type = mask.type; traits.bias_type = bias.type; traits.has_stream_sink = mask.sink > 0 ? true : false; // StreamLLM sink - traits.has_gptoss_sink = init_sink_value != 0; // GPT-OSS sink + // GPT-OSS sink: callers must pass exactly 0 (not epsilon) to disable. + // This is an integer-exact float comparison; any non-zero value enables the sink. + traits.has_gptoss_sink = init_sink_value != 0; traits.has_lse = lse; if constexpr(std::is_same_v>) diff --git a/projects/composablekernel/example/ck_tile/01_fmha/generate.py b/projects/composablekernel/example/ck_tile/01_fmha/generate.py index 6e6ed00944a..88492086fbe 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/generate.py +++ b/projects/composablekernel/example/ck_tile/01_fmha/generate.py @@ -37,6 +37,10 @@ class HandlerId(IntEnum): ) assert 0 < len(handlers) +# APIs that do not support sink modes (sink_modes is not forwarded to their handlers). +# These APIs predate the StreamLLM/GPT-OSS sink split and have no sink kernel variants. +_APIS_WITHOUT_SINK = {"bwd"} + def write_blobs( targets: List[str], @@ -57,7 +61,7 @@ def write_blobs( for api, kernel_filter in zip(api_list, filters_list): handler = handlers[api][HandlerId.WRITE_BLOBS] - if api == "bwd": + if api in _APIS_WITHOUT_SINK: handler(targets, output_dir, kernel_filter, receipt, optdim_list, mask_impl) else: handler( @@ -90,7 +94,7 @@ def list_blobs( for api, kernel_filter in zip(api_list, filters_list): handler = handlers[api][HandlerId.LIST_BLOBS] - if api == "bwd": + if api in _APIS_WITHOUT_SINK: handler(targets, file_path, kernel_filter, receipt, optdim_list, mask_impl) else: handler( diff --git a/projects/composablekernel/example/ck_tile/01_fmha/script/smoke_test_fwd.sh b/projects/composablekernel/example/ck_tile/01_fmha/script/smoke_test_fwd.sh index 6578055d952..9784b9b2096 100755 --- a/projects/composablekernel/example/ck_tile/01_fmha/script/smoke_test_fwd.sh +++ b/projects/composablekernel/example/ck_tile/01_fmha/script/smoke_test_fwd.sh @@ -332,10 +332,14 @@ run_fp8fp32_tests if [ $TEST_STREAM_SINK -eq 1 ] ; then run_sink_mask_tests +else + echo ">>> Skipping StreamLLM sink tests (use -m to enable)" fi if [ $TEST_GPTOSS_SINK -eq 1 ] ; then run_sink_init_tests +else + echo ">>> Skipping GPT-OSS sink tests (use -g to enable)" fi if [ $TEST_APPENDKV -eq 1 ] ; then diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp index 5ce7d2ab801..31695ac4db1 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp @@ -7,6 +7,7 @@ #include "ck_tile/ops/common.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/variants.hpp" +#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include #include @@ -104,7 +105,7 @@ struct FmhaFwdPagedKVKernel "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) + (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ) + - (!kHasStreamSink && !kHasGptOssSink ? "_nsink" : kHasStreamSink && !kHasGptOssSink ? "_ssink" : !kHasStreamSink && kHasGptOssSink ? "_gsink" : "_bsink"); + ck_tile::FmhaSinkNameSuffix(); #undef _SS_ #undef _TS_ // clang-format on diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index d7869ada0bb..bc979a21576 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -7,6 +7,7 @@ #include "ck_tile/ops/common.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/variants.hpp" +#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include #include @@ -104,7 +105,7 @@ struct FmhaFwdSplitKVKernel (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant") + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ) + - (!kHasStreamSink && !kHasGptOssSink ? "_nsink" : kHasStreamSink && !kHasGptOssSink ? "_ssink" : !kHasStreamSink && kHasGptOssSink ? "_gsink" : "_bsink"); + ck_tile::FmhaSinkNameSuffix(); #undef _SS_ #undef _TS_ // clang-format on diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index af0bc90915f..04d7b415f2a 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -565,15 +565,9 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync { auto lse = make_static_distributed_tensor(m.get_tile_distribution()); - if constexpr(kHasGptOssSink) - { - const SMPLComputeDataType sink_lse = sink_v * scale_s; - set_tile(lse, sink_lse); - } - else - { - set_tile(lse, -numeric::infinity()); - } + // Already inside if constexpr(kHasGptOssSink); sink_v is always valid. + const SMPLComputeDataType sink_lse = sink_v * scale_s; + set_tile(lse, sink_lse); store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); } buffer_load_fence(0); // rocm-6.1, if whole tile is masked out, need to fence(0) diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index 42c5a172c0a..bc2684f18ab 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -24,6 +24,34 @@ enum class FmhaSinkMode : int kBoth = 3, // Both sinks simultaneously }; +// Helper that derives kHasStreamSink / kHasGptOssSink from a FmhaSinkMode value. +// Use as a base or member in traits structs to avoid repeating the same two-line +// derivation in every traits type. +template +struct FmhaSinkModeHelper +{ + static constexpr bool kHasStreamSink = + (kSinkMode_ == FmhaSinkMode::kStreamLLM || kSinkMode_ == FmhaSinkMode::kBoth); + static constexpr bool kHasGptOssSink = + (kSinkMode_ == FmhaSinkMode::kGptOss || kSinkMode_ == FmhaSinkMode::kBoth); +}; + +// Returns the kernel name suffix for a given sink mode combination. +// Mirrors the Python-side SINK_NAME_MAP: none->"_nsink", stream->"_ssink", +// gptoss->"_gsink", both->"_bsink". +template +constexpr const char* FmhaSinkNameSuffix() +{ + if constexpr(!HasStream && !HasGptOss) + return "_nsink"; + else if constexpr(HasStream && !HasGptOss) + return "_ssink"; + else if constexpr(!HasStream && HasGptOss) + return "_gsink"; + else + return "_bsink"; +} + template ::kHasStreamSink; + static constexpr bool kHasGptOssSink = FmhaSinkModeHelper::kHasGptOssSink; }; template ::kHasStreamSink; + static constexpr bool kHasGptOssSink = FmhaSinkModeHelper::kHasGptOssSink; }; template ::kHasStreamSink; + static constexpr bool kHasGptOssSink = FmhaSinkModeHelper::kHasGptOssSink; }; template Date: Mon, 13 Apr 2026 20:29:02 -0500 Subject: [PATCH 17/25] [CK Tile][FMHA] Fix code alignment for sink-related fields --- .../example/ck_tile/01_fmha/fmha_fwd.hpp | 18 ++++++------------ .../ck_tile/01_fmha/fmha_fwd_runner.hpp | 4 ++-- .../ops/fmha/pipeline/tile_fmha_traits.hpp | 18 +++++++++--------- 3 files changed, 17 insertions(+), 23 deletions(-) diff --git a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp index f57dab750f9..b23d047f956 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -1427,10 +1427,8 @@ struct fmha_fwd_traits_ static constexpr bool kUseTrLoad = kUseTrLoad_; static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; static constexpr ck_tile::FmhaSinkMode kSinkMode = kSinkMode_; - static constexpr bool kHasStreamSink = - ck_tile::FmhaSinkModeHelper::kHasStreamSink; - static constexpr bool kHasGptOssSink = - ck_tile::FmhaSinkModeHelper::kHasGptOssSink; + static constexpr bool kHasStreamSink = ck_tile::FmhaSinkModeHelper::kHasStreamSink; + static constexpr bool kHasGptOssSink = ck_tile::FmhaSinkModeHelper::kHasGptOssSink; }; template ::kHasStreamSink; - static constexpr bool kHasGptOssSink = - ck_tile::FmhaSinkModeHelper::kHasGptOssSink; + static constexpr bool kHasStreamSink = ck_tile::FmhaSinkModeHelper::kHasStreamSink; + static constexpr bool kHasGptOssSink = ck_tile::FmhaSinkModeHelper::kHasGptOssSink; }; template @@ -1598,10 +1594,8 @@ struct fmha_fwd_splitkv_traits_ static constexpr bool kPadDv = kPadDv_; static constexpr bool kIsPagedKV = kIsPagedKV_; static constexpr ck_tile::FmhaSinkMode kSinkMode = kSinkMode_; - static constexpr bool kHasStreamSink = - ck_tile::FmhaSinkModeHelper::kHasStreamSink; - static constexpr bool kHasGptOssSink = - ck_tile::FmhaSinkModeHelper::kHasGptOssSink; + static constexpr bool kHasStreamSink = ck_tile::FmhaSinkModeHelper::kHasStreamSink; + static constexpr bool kHasGptOssSink = ck_tile::FmhaSinkModeHelper::kHasGptOssSink; }; template diff --git a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index 820a82836c5..0eeb4470ba5 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -1136,8 +1136,8 @@ fwd_result fmha_fwd_run(mode_enum mode, traits.has_stream_sink = mask.sink > 0 ? true : false; // StreamLLM sink // GPT-OSS sink: callers must pass exactly 0 (not epsilon) to disable. // This is an integer-exact float comparison; any non-zero value enables the sink. - traits.has_gptoss_sink = init_sink_value != 0; - traits.has_lse = lse; + traits.has_gptoss_sink = init_sink_value != 0; + traits.has_lse = lse; if constexpr(std::is_same_v>) { diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index bc2684f18ab..156c4eb53e0 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -78,10 +78,10 @@ struct TileFmhaTraits static constexpr bool kHasDropout = kHasDropout_; static constexpr auto QScaleEnum = QScaleEnum_; static constexpr index_t kBlockPerCu = kBlockPerCu_; - static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; - static constexpr FmhaSinkMode kSinkMode = kSinkMode_; - static constexpr bool kHasStreamSink = FmhaSinkModeHelper::kHasStreamSink; - static constexpr bool kHasGptOssSink = FmhaSinkModeHelper::kHasGptOssSink; + static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; + static constexpr FmhaSinkMode kSinkMode = kSinkMode_; + static constexpr bool kHasStreamSink = FmhaSinkModeHelper::kHasStreamSink; + static constexpr bool kHasGptOssSink = FmhaSinkModeHelper::kHasGptOssSink; }; template ::kHasStreamSink; - static constexpr bool kHasGptOssSink = FmhaSinkModeHelper::kHasGptOssSink; + static constexpr FmhaSinkMode kSinkMode = kSinkMode_; + static constexpr bool kHasStreamSink = FmhaSinkModeHelper::kHasStreamSink; + static constexpr bool kHasGptOssSink = FmhaSinkModeHelper::kHasGptOssSink; }; template ::kHasStreamSink; - static constexpr bool kHasGptOssSink = FmhaSinkModeHelper::kHasGptOssSink; + static constexpr bool kHasStreamSink = FmhaSinkModeHelper::kHasStreamSink; + static constexpr bool kHasGptOssSink = FmhaSinkModeHelper::kHasGptOssSink; }; template Date: Wed, 15 Apr 2026 03:30:25 -0500 Subject: [PATCH 18/25] [CK Tile][FMHA] Remove stray compiler workaround from sink mode refactor Remove an unrelated qr_async fallback branch for IsMasking=true + (no_bias+dropout) or (alibi+no_dropout) combinations that was accidentally mixed into the FmhaSinkMode refactor commit. --- .../example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 22b7fd57bfa..827bbf64d60 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -1066,14 +1066,6 @@ def get_pipelines( # TODO: rocm 6.2 compiler problem if using qr_async for bias case pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip - elif mask not in ("s_no", "no") and ( - (bias == "no" and dropout == "t") - or (bias == "alibi" and dropout == "f") - ): - # TODO: compiler problem with qr_async for IsMasking=true + - # (no_bias+dropout) or (alibi+ndropout) combinations - pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip else: pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip From 973e0414e6abb3ffe3e03e024d4cd3de960a090b Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Wed, 15 Apr 2026 03:44:30 -0500 Subject: [PATCH 19/25] =?UTF-8?q?[CK=20Tile][FMHA]=20Fix=20F=5Fsink=3D"f"?= =?UTF-8?q?=20=E2=86=92=20"none"=20after=20FmhaSinkMode=20refactor?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit After the FmhaSinkMode enum refactor, valid values for F_sink are "none"/"stream"/"gptoss"/"both". Two stale "f" values were left behind: - fp8bf16 qr_async_trload_v3 pipeline construction (F_sink="f") - receipt=2 kernel filter condition (F_sink == "f") Both caused KeyError at codegen time. --- .../example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 827bbf64d60..183c05a5ef6 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -1173,7 +1173,7 @@ def get_pipelines( ["no", "pertensor"], ["no", "causal"], ): - pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f", F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f", F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="none")) # fmt: skip elif dtype in cls._DT_MXFP8 or dtype in cls._DT_MXFP4: # no need dropout kernels @@ -1412,7 +1412,7 @@ def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: cond &= kernel_ctx.pipeline.F_bias in ["no", "alibi"] cond &= kernel_ctx.pipeline.F_qscale == "no" cond &= kernel_ctx.pipeline.F_skip == "f" - cond &= kernel_ctx.pipeline.F_sink == "f" + cond &= kernel_ctx.pipeline.F_sink == "none" # FlashAttention direct fwd wrappers always use softcap disabled and LSE enabled. cond &= kernel_ctx.pipeline.F_logits == "f" cond &= kernel_ctx.pipeline.F_lse == "t" From c17001a5d79c147ade630424692d45060057a225 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Thu, 16 Apr 2026 05:03:21 -0500 Subject: [PATCH 20/25] [CK] Skip fp16 dropout d256/d24 batch tests for compiler VGPR aliasing bug (ROCm 7.1.x) The AMDGPU compiler on ROCm 7.1.x miscompiles fp16 dropout kernels with d256 tile under high register pressure: Philox RNG VGPRs (ph_seed, ph_head_offset) get aliased with other live data, producing corrupted output. Fixed in ROCm 7.2. - config.hpp: Add CK_TILE_WORKAROUND_ROCM_7_1_FP16_DROPOUT_MISCOMPILE macro (active when HIP_VERSION_MAJOR==7 && HIP_VERSION_MINOR==1) - test_fmha_fwd.cpp: GTEST_SKIP the 4 affected AllLong cases (fp16, batch, hdim_q=256, hdim_v=24, p_drop>0) on ROCm 7.1.x --- .../composablekernel/include/ck_tile/core/config.hpp | 12 ++++++++++++ .../test/ck_tile/fmha/test_fmha_fwd.cpp | 8 ++++++++ 2 files changed, 20 insertions(+) diff --git a/projects/composablekernel/include/ck_tile/core/config.hpp b/projects/composablekernel/include/ck_tile/core/config.hpp index 036e241c95e..f034ec321bf 100644 --- a/projects/composablekernel/include/ck_tile/core/config.hpp +++ b/projects/composablekernel/include/ck_tile/core/config.hpp @@ -209,6 +209,18 @@ #endif #endif +// workaround for AMDGPU compiler VGPR aliasing bug in dropout codegen (ROCm 7.1.x) +// Philox RNG VGPR parameters get aliased under high register pressure (d256 tile). +// fp16 is affected; bf16 is not (different type conversion codegen path). +// Fixed in ROCm 7.2. +#ifndef CK_TILE_WORKAROUND_ROCM_7_1_FP16_DROPOUT_MISCOMPILE +#if(HIP_VERSION_MAJOR == 7 && HIP_VERSION_MINOR == 1) +#define CK_TILE_WORKAROUND_ROCM_7_1_FP16_DROPOUT_MISCOMPILE 1 +#else +#define CK_TILE_WORKAROUND_ROCM_7_1_FP16_DROPOUT_MISCOMPILE 0 +#endif +#endif + #ifndef CK_TILE_DEBUG_LOG #define CK_TILE_DEBUG_LOG 0 #endif diff --git a/projects/composablekernel/test/ck_tile/fmha/test_fmha_fwd.cpp b/projects/composablekernel/test/ck_tile/fmha/test_fmha_fwd.cpp index c2a90360d98..c7c62531c33 100644 --- a/projects/composablekernel/test/ck_tile/fmha/test_fmha_fwd.cpp +++ b/projects/composablekernel/test/ck_tile/fmha/test_fmha_fwd.cpp @@ -218,6 +218,14 @@ TEST_P(AllLong, DataTypeConfig) hdim_q = hdim_q_ == -1 ? hdim_q : hdim_q_; hdim_v = hdim_v_ == -1 ? hdim_v : hdim_v_; +#if CK_TILE_WORKAROUND_ROCM_7_1_FP16_DROPOUT_MISCOMPILE + if constexpr(std::is_same_v) + { + if(hdim_q == 256 && hdim_v == 24 && mode == mode_enum::batch && p_drop > 0) + GTEST_SKIP() << "Skipped: fp16 dropout d256/d24 batch — compiler VGPR aliasing bug (ROCm 7.1.x)"; + } +#endif + auto result = fmha_fwd_run(mode, batch, nhead, From 841bfa0ff2a61835f90d41039b1d7305d5e4c45e Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Thu, 16 Apr 2026 05:04:11 -0500 Subject: [PATCH 21/25] [CK] Restrict fp16 dropout d256/d24 batch skip to gfx950 The VGPR aliasing miscompile only affects gfx950; add ck_tile::is_gfx95_supported() runtime check to the GTEST_SKIP condition. --- .../composablekernel/test/ck_tile/fmha/test_fmha_fwd.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/projects/composablekernel/test/ck_tile/fmha/test_fmha_fwd.cpp b/projects/composablekernel/test/ck_tile/fmha/test_fmha_fwd.cpp index c7c62531c33..8fa729a6fa7 100644 --- a/projects/composablekernel/test/ck_tile/fmha/test_fmha_fwd.cpp +++ b/projects/composablekernel/test/ck_tile/fmha/test_fmha_fwd.cpp @@ -221,8 +221,9 @@ TEST_P(AllLong, DataTypeConfig) #if CK_TILE_WORKAROUND_ROCM_7_1_FP16_DROPOUT_MISCOMPILE if constexpr(std::is_same_v) { - if(hdim_q == 256 && hdim_v == 24 && mode == mode_enum::batch && p_drop > 0) - GTEST_SKIP() << "Skipped: fp16 dropout d256/d24 batch — compiler VGPR aliasing bug (ROCm 7.1.x)"; + if(hdim_q == 256 && hdim_v == 24 && mode == mode_enum::batch && p_drop > 0 && + ck_tile::is_gfx95_supported()) + GTEST_SKIP() << "Skipped: fp16 dropout d256/d24 batch gfx950 — compiler VGPR aliasing bug (ROCm 7.1.x)"; } #endif From def192e66cbfd7f519dd5f502d6e2548b1d87149 Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Thu, 16 Apr 2026 05:19:44 -0500 Subject: [PATCH 22/25] [CK] Fix clang-format: wrap long GTEST_SKIP string literal --- projects/composablekernel/test/ck_tile/fmha/test_fmha_fwd.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/projects/composablekernel/test/ck_tile/fmha/test_fmha_fwd.cpp b/projects/composablekernel/test/ck_tile/fmha/test_fmha_fwd.cpp index 63013c38377..8cf6c22c02f 100644 --- a/projects/composablekernel/test/ck_tile/fmha/test_fmha_fwd.cpp +++ b/projects/composablekernel/test/ck_tile/fmha/test_fmha_fwd.cpp @@ -223,7 +223,8 @@ TEST_P(AllLong, DataTypeConfig) { if(hdim_q == 256 && hdim_v == 24 && mode == mode_enum::batch && p_drop > 0 && ck_tile::is_gfx95_supported()) - GTEST_SKIP() << "Skipped: fp16 dropout d256/d24 batch gfx950 — compiler VGPR aliasing bug (ROCm 7.1.x)"; + GTEST_SKIP() << "Skipped: fp16 dropout d256/d24 batch gfx950 — compiler VGPR aliasing " + "bug (ROCm 7.1.x)"; } #endif From 22d8f5240930ecb5e51696c47f58f2891672709e Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Thu, 16 Apr 2026 19:18:56 -0500 Subject: [PATCH 23/25] [CK] Remove gfx95 guard from fp16 dropout d256/d24 batch skip (ROCm 7.1.x) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The is_gfx95_supported() check was redundant — the VGPR aliasing miscompile affects all gfx9 targets on ROCm 7.1.x, not only gfx950. --- .../composablekernel/test/ck_tile/fmha/test_fmha_fwd.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/projects/composablekernel/test/ck_tile/fmha/test_fmha_fwd.cpp b/projects/composablekernel/test/ck_tile/fmha/test_fmha_fwd.cpp index 8cf6c22c02f..d440b2d9c1f 100644 --- a/projects/composablekernel/test/ck_tile/fmha/test_fmha_fwd.cpp +++ b/projects/composablekernel/test/ck_tile/fmha/test_fmha_fwd.cpp @@ -221,9 +221,8 @@ TEST_P(AllLong, DataTypeConfig) #if CK_TILE_WORKAROUND_ROCM_7_1_FP16_DROPOUT_MISCOMPILE if constexpr(std::is_same_v) { - if(hdim_q == 256 && hdim_v == 24 && mode == mode_enum::batch && p_drop > 0 && - ck_tile::is_gfx95_supported()) - GTEST_SKIP() << "Skipped: fp16 dropout d256/d24 batch gfx950 — compiler VGPR aliasing " + if(hdim_q == 256 && hdim_v == 24 && mode == mode_enum::batch && p_drop > 0) + GTEST_SKIP() << "Skipped: fp16 dropout d256/d24 batch — compiler VGPR aliasing " "bug (ROCm 7.1.x)"; } #endif From f6e785f9fd2c38542bfc8c3704caa2f462ef673c Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Thu, 16 Apr 2026 20:57:32 -0500 Subject: [PATCH 24/25] [CK] Add fp16 dropout d256 batch skip for ROCm 7.1.x to Dropout test The existing 7.1.x workaround was only in AllLong (gated by env var), leaving the always-on Dropout test exposed to the same VGPR aliasing miscompile bug. --- .../composablekernel/test/ck_tile/fmha/test_fmha_fwd.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/projects/composablekernel/test/ck_tile/fmha/test_fmha_fwd.cpp b/projects/composablekernel/test/ck_tile/fmha/test_fmha_fwd.cpp index d440b2d9c1f..b22ef209077 100644 --- a/projects/composablekernel/test/ck_tile/fmha/test_fmha_fwd.cpp +++ b/projects/composablekernel/test/ck_tile/fmha/test_fmha_fwd.cpp @@ -610,6 +610,14 @@ TEST_P(Dropout, DataTypeConfig) auto [drop_seed, drop_offset, drop_prefs] = drop_seed_offset_prefs; auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask; +#if CK_TILE_WORKAROUND_ROCM_7_1_FP16_DROPOUT_MISCOMPILE + if constexpr(std::is_same_v) + { + if(hdim_q > 128 && mode == mode_enum::batch) + GTEST_SKIP() << "Skipped: fp16 dropout d256 batch — compiler VGPR aliasing " + "bug (ROCm 7.1.x)"; + } +#endif #if CK_TILE_WORKAROUND_ROCM_7_12_FP16_DROPOUT_MISCOMPILE if constexpr(std::is_same_v) { From 77b26adec1c246d518d401c8522ea6241d02af9f Mon Sep 17 00:00:00 2001 From: Linjun-AMD Date: Fri, 17 Apr 2026 01:42:39 -0500 Subject: [PATCH 25/25] [fmha] Skip fp16 d256 batch bias=e mask dropout on ROCm 7.1.x in smoke_test_fwd.sh ROCm 7.1.x has a compiler VGPR aliasing miscompile that causes wrong results for fp16 d256 batch mode with bias=e, mask, and dropout. Detect the ROCm version at runtime and skip the affected run_exe invocation when on 7.1.x. --- .../ck_tile/01_fmha/script/smoke_test_fwd.sh | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/projects/composablekernel/example/ck_tile/01_fmha/script/smoke_test_fwd.sh b/projects/composablekernel/example/ck_tile/01_fmha/script/smoke_test_fwd.sh index 9784b9b2096..8d73957a399 100755 --- a/projects/composablekernel/example/ck_tile/01_fmha/script/smoke_test_fwd.sh +++ b/projects/composablekernel/example/ck_tile/01_fmha/script/smoke_test_fwd.sh @@ -17,6 +17,16 @@ fi export CK_WARMUP=0 export CK_REPEAT=1 +# Detect ROCm 7.1.x: compiler VGPR aliasing miscompile affects fp16 d256 batch+dropout+bias+mask +_hip_ver=$(hipcc --version 2>/dev/null | awk '/HIP version/{print $NF}') +_hip_major=${_hip_ver%%.*} +_hip_minor=${_hip_ver#*.}; _hip_minor=${_hip_minor%%.*} +SKIP_ROCM_71_FP16_D256_BATCH_DROPOUT=0 +if [ "${_hip_major}" = "7" ] && [ "${_hip_minor}" = "1" ]; then + SKIP_ROCM_71_FP16_D256_BATCH_DROPOUT=1 + echo ">>> ROCm 7.1.x detected: skipping fp16 d256 batch bias=e mask dropout cases (VGPR aliasing miscompile)" +fi + CURR_FAILS_FILE=${CURR_FAILS_FILE:-"fmha_fwd_fails_$GPU_arch.txt"} rm -f $CURR_FAILS_FILE touch $CURR_FAILS_FILE @@ -110,7 +120,11 @@ run_fp16_bf16_tests() { run_exe -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + # Skip fp16 d256 batch bias=e mask dropout on ROCm 7.1.x (VGPR aliasing miscompile bug) + if ! { [ "$SKIP_ROCM_71_FP16_D256_BATCH_DROPOUT" = "1" ] && [ "$prec" = "fp16" ] && \ + [ "$mode" = "0" ] && [ "$hdim" = "256" ] && [ "$bias" = "e" ] && [ "$p_drop" != "0.0" ]; }; then run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + fi run_exe -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS