diff --git a/projects/composablekernel/example/ck_tile/01_fmha/CMakeLists.txt b/projects/composablekernel/example/ck_tile/01_fmha/CMakeLists.txt index 35afb1181e0..8153982f57d 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/CMakeLists.txt +++ b/projects/composablekernel/example/ck_tile/01_fmha/CMakeLists.txt @@ -43,13 +43,23 @@ set_directory_properties(PROPERTIES CMAKE_CONFIGURE_DEPENDS "${CODE_GEN_SCRIPTS} list(JOIN INST_TARGETS , FMHA_TARGETS_ARG) string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}") +set(FMHA_FWD_OPTDIM "32,64,80,128,256" CACHE STRING + "comma-separated list of hdim values to optimize for") +set(FMHA_FWD_FILTER "" CACHE STRING + "fnmatch filter pattern for fwd kernel instances. Empty = no filter.") +set(FMHA_FWD_SINK_MODES "none" CACHE STRING + "comma-separated list of sink modes to generate (none,stream,gptoss,both). Default: none.") + set(FMHA_FWD_CODE_GEN_COMMON_ARGS ${CMAKE_CURRENT_LIST_DIR}/generate.py --targets ${FMHA_TARGETS_ARG} --api ${FMHA_FWD_APIS} - --optdim 32,64,80,128,256 - # --filter fmha_fwd... + --optdim ${FMHA_FWD_OPTDIM} + --sink ${FMHA_FWD_SINK_MODES} ) +if(FMHA_FWD_FILTER) + list(APPEND FMHA_FWD_CODE_GEN_COMMON_ARGS --filter ${FMHA_FWD_FILTER}) +endif() set(FMHA_BWD_CODE_GEN_COMMON_ARGS ${CMAKE_CURRENT_LIST_DIR}/generate.py --targets ${FMHA_TARGETS_ARG} @@ -98,6 +108,7 @@ add_custom_command( --output_dir ${CMAKE_CURRENT_BINARY_DIR} DEPENDS ${CODE_GEN_SCRIPTS} COMMENT "Generate CK Tile FMHA FWD kernels" + VERBATIM ) add_custom_command( diff --git a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index 35e8c1be49d..7863ab3fbae 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -828,7 +828,11 @@ def write_blobs( receipt, optdim_list, mask_impl, + sink_modes=("none",), ) -> None: + # sink_modes is intentionally not forwarded: batch_prefill does not yet support + # StreamLLM/GPT-OSS sink kernels. The parameter exists only for API uniformity + # with other fwd handlers. Non-"none" modes are silently treated as "none". api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) for kernel in kernels: write_single_fwd_kernel(kernel, output_dir) @@ -842,7 +846,9 @@ def list_blobs( receipt, optdim_list, mask_impl, + sink_modes=("none",), ) -> None: + # sink_modes is intentionally not forwarded: see write_blobs for rationale. with file_path.open("a") as f: _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) for kernel in kernels: diff --git a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index c64a19104e6..183c05a5ef6 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -42,6 +42,30 @@ "mxfp4": 4, } +# Maps FmhaSinkMode string values to C++ FmhaSinkMode enum +SINK_MODE_MAP = { + "none": "ck_tile::FmhaSinkMode::kNone", + "stream": "ck_tile::FmhaSinkMode::kStreamLLM", + "gptoss": "ck_tile::FmhaSinkMode::kGptOss", + "both": "ck_tile::FmhaSinkMode::kBoth", +} + +# For dispatch check: map sink mode to (has_stream_sink, has_gptoss_sink) +SINK_MODE_DISPATCH_MAP = { + "none": ("false", "false"), + "stream": ("true", "false"), + "gptoss": ("false", "true"), + "both": ("true", "true"), +} + +# Legacy sink name suffix for kernel files +SINK_NAME_MAP = { + "none": "_nsink", + "stream": "_ssink", + "gptoss": "_gsink", + "both": "_bsink", +} + K0_MAX_SUBMAX_MAP = { 32: 32, 48: 48, @@ -104,7 +128,7 @@ {F_qscale}, {F_occupancy}, {F_skip}, - {F_sink}>; + {F_sink_mode}>; using fmha_variant = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; @@ -141,7 +165,7 @@ using trait = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}, {F_sink}>; + {F_pipeline_enum}, {F_logits}, fmha_mask, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}, {F_sink_mode}>; template<> float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) @@ -248,9 +272,9 @@ }} """ -FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.skip_min_seqlen_q == {F_skip}) &&(t.has_sink == {F_sink}) && +FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.skip_min_seqlen_q == {F_skip}) && (t.has_stream_sink == {F_stream_sink}) && (t.has_gptoss_sink == {F_gptoss_sink}) && ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ - using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}, {F_sink}>; + using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}, {F_sink_mode}>; return fmha_fwd_(s, a); }} """ @@ -414,7 +438,7 @@ class FmhaFwdPipeline: F_mask: str # value from MASK_MAP F_skip: str # true/false F_trload: str # true/false - F_sink: str # true/false + F_sink: str # "none" / "stream" / "gptoss" / "both" F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint()) @property @@ -485,10 +509,7 @@ def pad_name() -> str: n += "_trload" else: n += "_ntrload" - if self.F_sink == "t": - n += "_sink" - else: - n += "_nsink" + n += SINK_NAME_MAP[self.F_sink] return n @@ -578,7 +599,9 @@ def has_traits(node) -> bool: F_trload=BOOL_MAP[trait.tr_load], F_qscale_check=QSCALE_CHECK_MAP[trait.qscale], F_qscale=QSCALE_MAP[trait.qscale], - F_sink=BOOL_MAP[trait.sink], + F_sink_mode=SINK_MODE_MAP[trait.sink], + F_stream_sink=SINK_MODE_DISPATCH_MAP[trait.sink][0], + F_gptoss_sink=SINK_MODE_DISPATCH_MAP[trait.sink][1], F_scheck=trait.scheck, F_seqtune=trait.seqtune(max_bm0), F_skcheck=trait.skcheck, @@ -728,7 +751,7 @@ def render(self) -> str: F_pipeline=PIPELINE_MAP[self.F_pipeline.tag], F_kernel=self._get_cpp_kernel_class_name(self.F_pipeline.tag), F_kargs_creator=self._get_cpp_kargs_creator_func_name(self.F_pipeline.tag), - F_sink=BOOL_MAP[self.F_pipeline.F_sink], + F_sink_mode=SINK_MODE_MAP[self.F_pipeline.F_sink], ) @property @@ -1001,7 +1024,7 @@ def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: # support this in future @classmethod def get_pipelines( - cls, dtype, hdim, hdim_v, receipt, mask_impl + cls, dtype, hdim, hdim_v, receipt, mask_impl, sink_modes=("none",) ) -> List[FmhaFwdPipeline]: # this function will populate a list possible pipelines # TODO: the order of List matters! the later in this list will be also be checked later @@ -1017,7 +1040,7 @@ def get_pipelines( ["t", "f"], ["t", "f"], ["t", "f"], - ["t", "f"], + sink_modes, ): pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "f", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip @@ -1031,7 +1054,7 @@ def get_pipelines( ["t", "f"], ["t", "f"], ["t", "f"], - ["t", "f"], + sink_modes, ): if hdim == 256 and hdim_v == 256: pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip @@ -1055,7 +1078,7 @@ def get_pipelines( ["no", "pertensor", "blockscale"], get_mask_map(mask_impl).keys(), ["no"], - ["f", "t"], + sink_modes, ): if hdim == 64: pipelines.append(FmhaFwdPipeline("qr", "row", "t", "f", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip @@ -1111,10 +1134,10 @@ def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: @classmethod def get_pipelines( - cls, dtype, hdim, hdim_v, receipt, mask_impl + cls, dtype, hdim, hdim_v, receipt, mask_impl, sink_modes=("none",) ) -> List[FmhaFwdPipeline]: pipelines = KernelComponentFactoryGfx9.get_pipelines( - dtype, hdim, hdim_v, receipt, mask_impl + dtype, hdim, hdim_v, receipt, mask_impl, sink_modes ) if dtype in cls._DT_FP16_BF16: qscale = "no" @@ -1125,7 +1148,7 @@ def get_pipelines( ["t", "f"], ["t", "f"], ["t", "f"], - ["t", "f"], + sink_modes, ): if ( (hdim, hdim_v) in [(64, 64), (128, 128)] @@ -1137,11 +1160,12 @@ def get_pipelines( pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "t", sink)) # fmt: skip pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "t", sink)) # fmt: skip - # # qr_async_trload_v3 bf16/fp16 not ready - # if (hdim, hdim_v) == (128, 128): - # for logits, mask in itertools.product(["t", "f"], ["no", "causal"]): - # pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f", - # F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip + # qr_async_trload_v3 only supports hdim=hdim_v=128 for now + if (hdim, hdim_v) == (128, 128): + # qr_async_trload_v3 only supports (generic) causal mask + for logits, mask in itertools.product(["t", "f"], ["no", "causal"]): + pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f", + F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="none")) # fmt: skip elif dtype in cls._DT_FP8BF16: # qr_async_trload_v3 only supports (generic) causal mask for logits, qscale, mask in itertools.product( @@ -1149,7 +1173,7 @@ def get_pipelines( ["no", "pertensor"], ["no", "causal"], ): - pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f", F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f", F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="none")) # fmt: skip elif dtype in cls._DT_MXFP8 or dtype in cls._DT_MXFP4: # no need dropout kernels @@ -1160,7 +1184,7 @@ def get_pipelines( ["mx"], get_mask_map(mask_impl).keys(), ["no"], - ["f", "t"], + sink_modes, ): pipelines.append(FmhaFwdPipeline("qr", "col", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, "f", "f", sink)) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "col", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, "f", "f", sink)) # fmt: skip @@ -1229,7 +1253,7 @@ def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: @classmethod def get_pipelines( - cls, dtype, hdim, hdim_v, receipt, mask_impl + cls, dtype, hdim, hdim_v, receipt, mask_impl, sink_modes=("none",) ) -> List[FmhaFwdPipeline]: pipelines = [] if dtype in cls._DT_FP16_BF16: @@ -1241,7 +1265,7 @@ def get_pipelines( ["t", "f"], ["t", "f"], ["t", "f"], - ["t", "f"], + sink_modes, ): # Keep only ttff/tttt for gfx11: ffff path is often similar or worse # pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip @@ -1304,7 +1328,7 @@ def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]: @classmethod def get_pipelines( - cls, dtype, hdim, hdim_v, receipt, mask_impl + cls, dtype, hdim, hdim_v, receipt, mask_impl, sink_modes=("none",) ) -> List[FmhaFwdPipeline]: pipelines = [] if dtype in cls._DT_FP16_BF16: @@ -1316,7 +1340,7 @@ def get_pipelines( ["t", "f"], ["t", "f"], ["t", "f"], - ["t", "f"], + sink_modes, ): # pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip @@ -1325,15 +1349,16 @@ def get_pipelines( pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip elif dtype in cls._DT_FP8_FP8BF16 or dtype in cls._DT_FP8FP32: # no need lse/dropout kernels - for logits, qscale, mask, bias in itertools.product( + for logits, qscale, mask, bias, sink in itertools.product( ["f"], ["no", "pertensor", "blockscale"], get_mask_map(mask_impl).keys(), ["no"], + sink_modes, ): - pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip return pipelines @@ -1387,7 +1412,7 @@ def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: cond &= kernel_ctx.pipeline.F_bias in ["no", "alibi"] cond &= kernel_ctx.pipeline.F_qscale == "no" cond &= kernel_ctx.pipeline.F_skip == "f" - cond &= kernel_ctx.pipeline.F_sink == "f" + cond &= kernel_ctx.pipeline.F_sink == "none" # FlashAttention direct fwd wrappers always use softcap disabled and LSE enabled. cond &= kernel_ctx.pipeline.F_logits == "f" cond &= kernel_ctx.pipeline.F_lse == "t" @@ -1403,7 +1428,7 @@ def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: cond &= kernel_ctx.pipeline.F_bias in ["no", "alibi"] cond &= kernel_ctx.pipeline.F_qscale == "no" cond &= kernel_ctx.pipeline.F_skip == "f" - cond &= kernel_ctx.pipeline.F_sink == "f" + cond &= kernel_ctx.pipeline.F_sink == "none" return cond return Product(name="Flash attention integration", rule=fit) @@ -1500,7 +1525,12 @@ def fit(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool: def get_fwd_blobs( - targets: List[str], kernel_filter: Optional[str], receipt, optdim_list, mask_impl + targets: List[str], + kernel_filter: Optional[str], + receipt, + optdim_list, + mask_impl, + sink_modes=("none",), ) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: gen = list() api_pool = FmhaFwdApiPool() @@ -1522,7 +1552,10 @@ def get_fwd_blobs( ) for tile, pipeline in itertools.product( - tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) + tiles, + factory.get_pipelines( + dtype, hdim, hdim_v, receipt, mask_impl, sink_modes + ), ): problem_ctx = ProblemContext( dtype=dtype, mode=mode, hdim=hdim, hdim_v=hdim_v @@ -1585,9 +1618,10 @@ def write_blobs( receipt, optdim_list, mask_impl, + sink_modes=("none",), ) -> None: api_pool, kernels = get_fwd_blobs( - targets, kernel_filter, receipt, optdim_list, mask_impl + targets, kernel_filter, receipt, optdim_list, mask_impl, sink_modes ) for kernel in kernels: write_single_fwd_kernel(kernel, output_dir) @@ -1601,10 +1635,11 @@ def list_blobs( receipt, optdim_list, mask_impl, + sink_modes=("none",), ) -> None: with file_path.open("a") as f: _, kernels = get_fwd_blobs( - targets, kernel_filter, receipt, optdim_list, mask_impl + targets, kernel_filter, receipt, optdim_list, mask_impl, sink_modes ) for kernel in kernels: f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n") diff --git a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py index 793a743df74..505640514ce 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +++ b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -493,7 +493,11 @@ def write_blobs( receipt, optdim_list, mask_impl, + sink_modes=("none",), ) -> None: + # sink_modes is intentionally not forwarded: appendkv does not support sink kernels. + # The parameter exists only for API uniformity with other fwd handlers. + # Non-"none" modes are silently treated as "none". api_pool, kernels = get_fwd_appendkv_blobs( targets, kernel_filter, receipt, mask_impl, optdim_list ) @@ -509,11 +513,16 @@ def list_blobs( receipt, optdim_list, mask_impl, + sink_modes=("none",), ) -> None: + # sink_modes is intentionally not forwarded: see write_blobs for rationale. with file_path.open("a") as f: _, kernels = get_fwd_appendkv_blobs( targets, kernel_filter, receipt, mask_impl, optdim_list ) for kernel in kernels: f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n") - f.write((file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME).as_posix() + "\n") + f.write( + (file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME).as_posix() + + "\n" + ) diff --git a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index c9bac50da10..51b30b88764 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -32,6 +32,9 @@ FMHA_FWD_API_PER_ARCH, FMHA_FWD_API_PER_DTYPE, FMHA_FWD_API_PER_HDIM_CASE, + SINK_MODE_MAP, + SINK_MODE_DISPATCH_MAP, + SINK_NAME_MAP, ) @@ -74,7 +77,7 @@ kHasUnevenSplits, kMergeNumHeadGroupsSeqLenQ, {F_occupancy}, - {F_sink}>; + {F_sink_mode}>; using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem< typename FmhaFwdTypeConfig::QDataType, @@ -118,7 +121,7 @@ }} // anonymous namespace using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_sink}, {F_spad}, {F_skpad}, {F_dpad}, + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_sink_mode}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; #pragma clang diagnostic push @@ -280,8 +283,8 @@ """ FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) && - ((a.block_table_ptr != nullptr) == {F_pagedkv}) && (t.has_sink == {F_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv},{F_sink}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + ((a.block_table_ptr != nullptr) == {F_pagedkv}) && (t.has_stream_sink == {F_stream_sink}) && (t.has_gptoss_sink == {F_gptoss_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ + using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv},{F_sink_mode}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; // get combine kernel tile sizes using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType; @@ -427,7 +430,7 @@ class FmhaFwdSplitKVPipeline: F_lse: str # F_squant: str # F_pagedkv: str # t/f - F_sink: str # t/f + F_sink: str # "none" / "stream" / "gptoss" / "both" F_mask: str # value from MASK_MAP @property @@ -488,10 +491,7 @@ def pad_name() -> str: n += "_pagedkv" else: n += "_npagedkv" - if self.F_sink == "t": - n += "_sink" - else: - n += "_nsink" + n += SINK_NAME_MAP[self.F_sink] return n @@ -574,7 +574,9 @@ def api(self) -> str: F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv], - F_sink=BOOL_MAP[trait.sink], + F_sink_mode=SINK_MODE_MAP[trait.sink], + F_stream_sink=SINK_MODE_DISPATCH_MAP[trait.sink][0], + F_gptoss_sink=SINK_MODE_DISPATCH_MAP[trait.sink][1], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, @@ -675,7 +677,7 @@ def template(self) -> str: F_squant=BOOL_MAP[self.F_pipeline.F_squant], F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv], F_occupancy=self.F_tile.F_occupancy, - F_sink=BOOL_MAP[self.F_pipeline.F_sink], + F_sink_mode=SINK_MODE_MAP[self.F_pipeline.F_sink], F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], F_mode=MODE_MAP[self.F_mode], @@ -740,7 +742,9 @@ def filename(self) -> str: class KernelComponentFactoryBase: @staticmethod - def get_pipelines(dtype, hdim, mask_impl) -> List[FmhaFwdSplitKVPipeline]: + def get_pipelines( + dtype, hdim, mask_impl, sink_modes=("none",) + ) -> List[FmhaFwdSplitKVPipeline]: # this function will populate a list possible pipelines # TODO: the order of List matters! the later in this list will be also be checked later # TODO: currently for qr pipeline, let "t" padding to appear later!! @@ -754,7 +758,7 @@ def get_pipelines(dtype, hdim, mask_impl) -> List[FmhaFwdSplitKVPipeline]: get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], - ["t", "f"], + sink_modes, ): pipelines.append(Pipeline("qr", "row", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip pipelines.append(Pipeline("qr", "row", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip @@ -764,8 +768,8 @@ def get_pipelines(dtype, hdim, mask_impl) -> List[FmhaFwdSplitKVPipeline]: for logits, mask, bias in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys() ): - pipelines.append(Pipeline("qr", "row", "f", "f", "f", "f", logits, bias, "t", squant, "f", "f", mask)) # fmt: skip - pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, "f", "f", mask)) # fmt: skip + pipelines.append(Pipeline("qr", "row", "f", "f", "f", "f", logits, bias, "t", squant, "f", "none", mask)) # fmt: skip + pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, "f", "none", mask)) # fmt: skip elif dtype in ["fp8fp16", "fp8bf16"]: # TODO None @@ -891,7 +895,12 @@ def get_factory(target: str): def get_fwd_splitkv_blobs( - targets: List[str], kernel_filter: Optional[str], receipt, mask_impl, optdim_list + targets: List[str], + kernel_filter: Optional[str], + receipt, + mask_impl, + optdim_list, + sink_modes=("none",), ) -> List[FmhaFwdSplitKVKernel]: Kernel = FmhaFwdSplitKVKernel @@ -907,7 +916,7 @@ def get_fwd_splitkv_blobs( for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): tile = d[hdim_str] hdim = int(hdim_str) - for pipeline in factory.get_pipelines(dtype, hdim, mask_impl): + for pipeline in factory.get_pipelines(dtype, hdim, mask_impl, sink_modes): if mode == "group": if pipeline.F_spad != "t" or pipeline.F_skpad != "t": # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not @@ -942,7 +951,7 @@ def get_fwd_splitkv_blobs( # FlashAttention splitkv paths use softcap-disabled kernels only. cond &= pipeline.F_logits == "f" cond &= pipeline.F_squant == "f" - cond &= pipeline.F_sink == "f" + cond &= pipeline.F_sink == "none" if not cond: continue # PyTorch integration @@ -952,7 +961,7 @@ def get_fwd_splitkv_blobs( cond &= pipeline.F_bias in ["no", "bias"] cond &= pipeline.F_squant == "f" cond &= mode == "batch" - cond &= pipeline.F_sink == "f" + cond &= pipeline.F_sink == "none" if not cond: continue # Aiter(mha_varlen_fwd) integration @@ -1058,6 +1067,7 @@ def write_blobs( receipt, optdim_list, mask_impl, + sink_modes=("none",), ) -> None: filter_list = filter_list.split("@") filter_list.extend([""] * (2 - len(filter_list))) @@ -1068,7 +1078,7 @@ def write_blobs( for kernel in combine_kernels: write_single_kernel(kernel, output_dir) kernels = get_fwd_splitkv_blobs( - targets, filter_list[1], receipt, mask_impl, optdim_list + targets, filter_list[1], receipt, mask_impl, optdim_list, sink_modes ) for kernel in kernels: write_single_kernel(kernel, output_dir) @@ -1129,6 +1139,7 @@ def list_blobs( receipt, optdim_list, mask_impl, + sink_modes=("none",), ) -> None: filter_list = filter_list.split("@") filter_list.extend([""] * (2 - len(filter_list))) @@ -1140,7 +1151,7 @@ def list_blobs( for kernel in kernels: f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n") kernels = get_fwd_splitkv_blobs( - targets, filter_list[1], receipt, mask_impl, optdim_list + targets, filter_list[1], receipt, mask_impl, optdim_list, sink_modes ) for kernel in kernels: f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n") diff --git a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py index 1ac1f1c38a7..1bd5c0787d9 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py +++ b/projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py @@ -31,6 +31,9 @@ FMHA_FWD_API_PER_ARCH, FMHA_FWD_API_PER_DTYPE, FMHA_FWD_API_PER_HDIM_CASE, + SINK_MODE_MAP, + SINK_MODE_DISPATCH_MAP, + SINK_NAME_MAP, ) @@ -66,7 +69,7 @@ {F_squant}, {F_occupancy}, {F_skip}, - {F_sink}>; + {F_sink_mode}>; using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; @@ -101,7 +104,7 @@ ck_tile::FmhaFwdPagedKVKernel; using trait_{F_idx} = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}, {F_sink}>; + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}, {F_sink_mode}>; template<> float fmha_fwd_pagedkv_(const ck_tile::stream_config& s, fmha_fwd_pagedkv_args a) @@ -130,9 +133,9 @@ }} """ -FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && (t.has_sink == {F_sink}) && +FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && (t.has_stream_sink == {F_stream_sink}) && (t.has_gptoss_sink == {F_gptoss_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip},{F_sink}>; + using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip},{F_sink_mode}>; return fmha_fwd_pagedkv_(s, a); }} """ @@ -258,7 +261,7 @@ class FmhaFwdPipeline: F_squant: str # F_mask: str # value from MASK_MAP F_skip: str # true/false - F_sink: str # true/false + F_sink: str # "none" / "stream" / "gptoss" / "both" @property def name(self) -> str: @@ -323,10 +326,7 @@ def pad_name() -> str: n += "_pagedkv" else: n += "_npagedkv" - if self.F_sink == "t": - n += "_sink" - else: - n += "_nsink" + n += SINK_NAME_MAP[self.F_sink] return n @@ -370,7 +370,9 @@ def api(self) -> str: F_lse=BOOL_MAP[trait.lse], F_pagedkv=BOOL_MAP[trait.pagedkv], F_skip=BOOL_MAP[trait.skip], - F_sink=BOOL_MAP[trait.sink], + F_sink_mode=SINK_MODE_MAP[trait.sink], + F_stream_sink=SINK_MODE_DISPATCH_MAP[trait.sink][0], + F_gptoss_sink=SINK_MODE_DISPATCH_MAP[trait.sink][1], F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, @@ -488,7 +490,7 @@ def template(self) -> str: F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv], F_squant=BOOL_MAP[self.F_pipeline.F_squant], F_skip=BOOL_MAP[self.F_pipeline.F_skip], - F_sink=BOOL_MAP[self.F_pipeline.F_sink], + F_sink_mode=SINK_MODE_MAP[self.F_pipeline.F_sink], F_occupancy=self.F_tile.F_occupancy, F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], @@ -541,7 +543,9 @@ def api_trait(self) -> FmhaFwdApiTrait: class KernelComponentFactoryBase: @staticmethod - def get_pipelines(dtype, hdim, mask_impl) -> List[FmhaFwdPipeline]: + def get_pipelines( + dtype, hdim, mask_impl, sink_modes=("none",) + ) -> List[FmhaFwdPipeline]: # this function will populate a list possible pipelines # TODO: the order of List matters! the later in this list will be also be checked later # TODO: currently for qr_pagedkv pipeline, let "t" padding to appear later!! @@ -555,17 +559,17 @@ def get_pipelines(dtype, hdim, mask_impl) -> List[FmhaFwdPipeline]: BIAS_MAP.keys(), ["t"], ["f"], - ["t", "f"], + sink_modes, ): pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "f", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip, sink)) # fmt: skip pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip, sink)) # fmt: skip elif dtype in ["fp8", "bf8"]: # no need lse/dropout kernels - for logits, mask, bias in itertools.product( - ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys() + for logits, mask, bias, sink in itertools.product( + ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), sink_modes ): - pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "f", "f", "f", "f", logits, bias, "f", "t", squant, mask, "f", "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", "t", squant, mask, "f", "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "f", "f", "f", "f", logits, bias, "f", "t", squant, mask, "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", "t", squant, mask, "f", sink)) # fmt: skip elif dtype in ["fp8fp16", "fp8bf16"]: pass # TODO else: @@ -655,7 +659,12 @@ def get_factory(target: str): def get_fwd_blobs( - targets: List[str], kernel_filter: Optional[str], receipt, optdim_list, mask_impl + targets: List[str], + kernel_filter: Optional[str], + receipt, + optdim_list, + mask_impl, + sink_modes=("none",), ) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: gen = list() api_pool = FmhaFwdApiPool(mask_impl) @@ -669,7 +678,7 @@ def get_fwd_blobs( for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): tile = d[hdim_str] hdim = int(hdim_str) - for pipeline in factory.get_pipelines(dtype, hdim, mask_impl): + for pipeline in factory.get_pipelines(dtype, hdim, mask_impl, sink_modes): # if pipeline.F_pagedkv == "f": # continue if mode == "group": @@ -709,7 +718,7 @@ def get_fwd_blobs( cond &= pipeline.F_bias in ["no", "alibi"] cond &= pipeline.F_squant == "f" cond &= pipeline.F_skip == "f" - cond &= pipeline.F_sink == "f" + cond &= pipeline.F_sink == "none" if not cond: continue # PyTorch integration @@ -719,7 +728,7 @@ def get_fwd_blobs( cond &= pipeline.F_bias in ["no", "bias"] cond &= pipeline.F_squant == "f" cond &= pipeline.F_skip == "f" - cond &= pipeline.F_sink == "f" + cond &= pipeline.F_sink == "none" if not cond: continue # Aiter(mha_fwd) integration @@ -773,9 +782,10 @@ def write_blobs( receipt, optdim_list, mask_impl, + sink_modes=("none",), ) -> None: api_pool, kernels = get_fwd_blobs( - targets, kernel_filter, receipt, optdim_list, mask_impl + targets, kernel_filter, receipt, optdim_list, mask_impl, sink_modes ) for kernel in kernels: write_single_fwd_kernel(kernel, output_dir) @@ -789,10 +799,11 @@ def list_blobs( receipt, optdim_list, mask_impl, + sink_modes=("none",), ) -> None: with file_path.open("a") as f: _, kernels = get_fwd_blobs( - targets, kernel_filter, receipt, optdim_list, mask_impl + targets, kernel_filter, receipt, optdim_list, mask_impl, sink_modes ) for kernel in kernels: f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n") diff --git a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp index 7d7d01bd051..b23d047f956 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -1399,8 +1399,8 @@ template + bool kSkipMinSeqlenQ_ = false, + ck_tile::FmhaSinkMode kSinkMode_ = ck_tile::FmhaSinkMode::kNone> struct fmha_fwd_traits_ { static constexpr ck_tile::index_t HDim = HDim_; @@ -1426,7 +1426,9 @@ struct fmha_fwd_traits_ static constexpr bool kPadDv = kPadDv_; static constexpr bool kUseTrLoad = kUseTrLoad_; static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; - static constexpr bool kHasSink = kHasSink_; + static constexpr ck_tile::FmhaSinkMode kSinkMode = kSinkMode_; + static constexpr bool kHasStreamSink = ck_tile::FmhaSinkModeHelper::kHasStreamSink; + static constexpr bool kHasGptOssSink = ck_tile::FmhaSinkModeHelper::kHasGptOssSink; }; template + ck_tile::FmhaSinkMode::kNone> { static constexpr auto kKVMemoryLayout = kKVMemoryLayout_; static constexpr auto kKVLookupTable = kKVLookupTable_; @@ -1512,8 +1514,8 @@ template + bool kSkipMinSeqlenQ_ = false, + ck_tile::FmhaSinkMode kSinkMode_ = ck_tile::FmhaSinkMode::kNone> struct fmha_fwd_pagedkv_traits_ { static constexpr ck_tile::index_t HDim = HDim_; @@ -1538,7 +1540,9 @@ struct fmha_fwd_pagedkv_traits_ static constexpr bool kPadD = kPadD_; static constexpr bool kPadDv = kPadDv_; static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; - static constexpr bool kHasSink = kHasSink_; + static constexpr ck_tile::FmhaSinkMode kSinkMode = kSinkMode_; + static constexpr bool kHasStreamSink = ck_tile::FmhaSinkModeHelper::kHasStreamSink; + static constexpr bool kHasGptOssSink = ck_tile::FmhaSinkModeHelper::kHasGptOssSink; }; template @@ -1561,7 +1565,7 @@ template ::kHasStreamSink; + static constexpr bool kHasGptOssSink = ck_tile::FmhaSinkModeHelper::kHasGptOssSink; }; template @@ -1676,7 +1682,8 @@ struct fmha_fwd_traits bool has_dropout; quant_scale_enum qscale_type; bool skip_min_seqlen_q = false; - bool has_sink = false; + bool has_stream_sink = false; // StreamLLM sliding-window sink + bool has_gptoss_sink = false; // GPT-OSS learnable softmax bias sink // TODO: padding check is inside this api }; float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&); @@ -1695,7 +1702,8 @@ struct fmha_fwd_pagedkv_traits bool use_pagedkv = true; bool do_fp8_static_quant = false; bool skip_min_seqlen_q = false; - bool has_sink = false; + bool has_stream_sink = false; // StreamLLM sliding-window sink + bool has_gptoss_sink = false; // GPT-OSS learnable softmax bias sink // TODO: padding check is inside this api }; @@ -1715,7 +1723,8 @@ struct fmha_fwd_splitkv_traits bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum bool has_lse; bool do_fp8_static_quant = false; - bool has_sink = false; + bool has_stream_sink = false; // StreamLLM sliding-window sink + bool has_gptoss_sink = false; // GPT-OSS learnable softmax bias sink // TODO: padding check is inside this api }; float fmha_fwd_splitkv(fmha_fwd_splitkv_traits, diff --git a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index 40b80063810..0eeb4470ba5 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -1133,8 +1133,11 @@ fwd_result fmha_fwd_run(mode_enum mode, traits.has_logits_soft_cap = 0.f < logits_soft_cap; traits.mask_type = mask.type; traits.bias_type = bias.type; - traits.has_sink = mask.sink > 0 ? true : false; - traits.has_lse = lse; + traits.has_stream_sink = mask.sink > 0 ? true : false; // StreamLLM sink + // GPT-OSS sink: callers must pass exactly 0 (not epsilon) to disable. + // This is an integer-exact float comparison; any non-zero value enables the sink. + traits.has_gptoss_sink = init_sink_value != 0; + traits.has_lse = lse; if constexpr(std::is_same_v>) { diff --git a/projects/composablekernel/example/ck_tile/01_fmha/generate.py b/projects/composablekernel/example/ck_tile/01_fmha/generate.py index a5a2d085635..88492086fbe 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/generate.py +++ b/projects/composablekernel/example/ck_tile/01_fmha/generate.py @@ -37,6 +37,10 @@ class HandlerId(IntEnum): ) assert 0 < len(handlers) +# APIs that do not support sink modes (sink_modes is not forwarded to their handlers). +# These APIs predate the StreamLLM/GPT-OSS sink split and have no sink kernel variants. +_APIS_WITHOUT_SINK = {"bwd"} + def write_blobs( targets: List[str], @@ -46,6 +50,7 @@ def write_blobs( optdim_list: List[int], receipt, mask_impl, + sink_modes=("none",), ) -> None: if output_dir is None: output_dir = Path(__file__).parent @@ -56,7 +61,18 @@ def write_blobs( for api, kernel_filter in zip(api_list, filters_list): handler = handlers[api][HandlerId.WRITE_BLOBS] - handler(targets, output_dir, kernel_filter, receipt, optdim_list, mask_impl) + if api in _APIS_WITHOUT_SINK: + handler(targets, output_dir, kernel_filter, receipt, optdim_list, mask_impl) + else: + handler( + targets, + output_dir, + kernel_filter, + receipt, + optdim_list, + mask_impl, + sink_modes, + ) # list all the files that will be generated @@ -68,6 +84,7 @@ def list_blobs( optdim_list: List[int], receipt, mask_impl, + sink_modes=("none",), ) -> None: assert output_file is not None file_path = Path(output_file) @@ -77,7 +94,18 @@ def list_blobs( for api, kernel_filter in zip(api_list, filters_list): handler = handlers[api][HandlerId.LIST_BLOBS] - handler(targets, file_path, kernel_filter, receipt, optdim_list, mask_impl) + if api in _APIS_WITHOUT_SINK: + handler(targets, file_path, kernel_filter, receipt, optdim_list, mask_impl) + else: + handler( + targets, + file_path, + kernel_filter, + receipt, + optdim_list, + mask_impl, + sink_modes, + ) if __name__ == "__main__": @@ -150,12 +178,29 @@ def list_blobs( + "eg. --optdim=32,64,128,256", ) + parser.add_argument( + "--sink", + default="none", + required=False, + help="comma-separated list of sink modes to generate instances for. " + + "Valid values: none, stream, gptoss, both. Default: none (only no-sink instances). " + + "eg. --sink=none,stream,gptoss", + ) + args = parser.parse_args() targets = args.targets.split(",") api_list = args.direction.split(",") filter_list = args.filter.split(",") filter_list.extend([""] * (len(api_list) - len(filter_list))) optdim_list = [int(hdim) for hdim in args.optdim.split(",")] + sink_modes = tuple(s.strip() for s in args.sink.split(",")) + valid_sink_modes = {"none", "stream", "gptoss", "both"} + invalid_sink_modes = set(sink_modes) - valid_sink_modes + if invalid_sink_modes: + parser.error( + f"Invalid sink mode(s): {sorted(invalid_sink_modes)}. " + f"Valid values are: {sorted(valid_sink_modes)}" + ) if args.list_blobs is not None: list_blobs( @@ -166,6 +211,7 @@ def list_blobs( optdim_list, int(args.receipt), mask_impl=args.mask, + sink_modes=sink_modes, ) else: write_blobs( @@ -176,4 +222,5 @@ def list_blobs( optdim_list, int(args.receipt), mask_impl=args.mask, + sink_modes=sink_modes, ) diff --git a/projects/composablekernel/example/ck_tile/01_fmha/script/smoke_test_fwd.sh b/projects/composablekernel/example/ck_tile/01_fmha/script/smoke_test_fwd.sh index 1e9942a6e1b..8d73957a399 100755 --- a/projects/composablekernel/example/ck_tile/01_fmha/script/smoke_test_fwd.sh +++ b/projects/composablekernel/example/ck_tile/01_fmha/script/smoke_test_fwd.sh @@ -17,6 +17,16 @@ fi export CK_WARMUP=0 export CK_REPEAT=1 +# Detect ROCm 7.1.x: compiler VGPR aliasing miscompile affects fp16 d256 batch+dropout+bias+mask +_hip_ver=$(hipcc --version 2>/dev/null | awk '/HIP version/{print $NF}') +_hip_major=${_hip_ver%%.*} +_hip_minor=${_hip_ver#*.}; _hip_minor=${_hip_minor%%.*} +SKIP_ROCM_71_FP16_D256_BATCH_DROPOUT=0 +if [ "${_hip_major}" = "7" ] && [ "${_hip_minor}" = "1" ]; then + SKIP_ROCM_71_FP16_D256_BATCH_DROPOUT=1 + echo ">>> ROCm 7.1.x detected: skipping fp16 d256 batch bias=e mask dropout cases (VGPR aliasing miscompile)" +fi + CURR_FAILS_FILE=${CURR_FAILS_FILE:-"fmha_fwd_fails_$GPU_arch.txt"} rm -f $CURR_FAILS_FILE touch $CURR_FAILS_FILE @@ -28,10 +38,33 @@ COMMON_ARGS='-v=1 -warmup=0 -repeat=1' TEST_SPLITKV=0 TEST_APPENDKV=0 -# options: -# -s: run splitkv tests -# -a: run appendkv tests -while getopts ":sa" opt; do +TEST_STREAM_SINK=0 +TEST_GPTOSS_SINK=0 +# Usage: +# bash smoke_test_fwd.sh [options] +# +# Options: +# -s run splitkv tests +# -a run appendkv tests +# -m run StreamLLM sliding-window sink mask tests +# requires the binary to be built with FMHA_FWD_SINK_MODES containing "stream", e.g.: +# cmake -DFMHA_FWD_SINK_MODES="none,stream" ... +# cmake --build . --target tile_example_fmha_fwd +# bash smoke_test_fwd.sh -m +# -g run GPT-OSS sink (init_sink) tests +# requires the binary to be built with FMHA_FWD_SINK_MODES containing "gptoss", e.g.: +# cmake -DFMHA_FWD_SINK_MODES="none,gptoss" ... +# cmake --build . --target tile_example_fmha_fwd +# bash smoke_test_fwd.sh -g +# +# Examples: +# bash smoke_test_fwd.sh # default: base tests only +# bash smoke_test_fwd.sh -s # also run splitkv tests +# bash smoke_test_fwd.sh -a # also run appendkv tests +# bash smoke_test_fwd.sh -m # also run StreamLLM sink tests +# bash smoke_test_fwd.sh -g # also run GPT-OSS sink tests +# bash smoke_test_fwd.sh -s -a -m -g # run all tests +while getopts ":samg" opt; do case "${opt}" in s) TEST_SPLITKV=1 @@ -39,6 +72,12 @@ while getopts ":sa" opt; do a) TEST_APPENDKV=1 ;; + m) + TEST_STREAM_SINK=1 + ;; + g) + TEST_GPTOSS_SINK=1 + ;; *) ;; esac @@ -81,7 +120,11 @@ run_fp16_bf16_tests() { run_exe -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + # Skip fp16 d256 batch bias=e mask dropout on ROCm 7.1.x (VGPR aliasing miscompile bug) + if ! { [ "$SKIP_ROCM_71_FP16_D256_BATCH_DROPOUT" = "1" ] && [ "$prec" = "fp16" ] && \ + [ "$mode" = "0" ] && [ "$hdim" = "256" ] && [ "$bias" = "e" ] && [ "$p_drop" != "0.0" ]; }; then run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + fi run_exe -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS @@ -300,8 +343,18 @@ run_padding_smoke_tests run_padding_basic_boundary_tests run_fp8bf16_tests run_fp8fp32_tests -run_sink_mask_tests -run_sink_init_tests + +if [ $TEST_STREAM_SINK -eq 1 ] ; then + run_sink_mask_tests +else + echo ">>> Skipping StreamLLM sink tests (use -m to enable)" +fi + +if [ $TEST_GPTOSS_SINK -eq 1 ] ; then + run_sink_init_tests +else + echo ">>> Skipping GPT-OSS sink tests (use -g to enable)" +fi if [ $TEST_APPENDKV -eq 1 ] ; then run_fp16_appendkv_tests diff --git a/projects/composablekernel/include/ck_tile/core/config.hpp b/projects/composablekernel/include/ck_tile/core/config.hpp index 06220d27800..34aac7f9639 100644 --- a/projects/composablekernel/include/ck_tile/core/config.hpp +++ b/projects/composablekernel/include/ck_tile/core/config.hpp @@ -209,6 +209,18 @@ #endif #endif +// workaround for AMDGPU compiler VGPR aliasing bug in dropout codegen (ROCm 7.1.x) +// Philox RNG VGPR parameters get aliased under high register pressure (d256 tile). +// fp16 is affected; bf16 is not (different type conversion codegen path). +// Fixed in ROCm 7.2. +#ifndef CK_TILE_WORKAROUND_ROCM_7_1_FP16_DROPOUT_MISCOMPILE +#if(HIP_VERSION_MAJOR == 7 && HIP_VERSION_MINOR == 1) +#define CK_TILE_WORKAROUND_ROCM_7_1_FP16_DROPOUT_MISCOMPILE 1 +#else +#define CK_TILE_WORKAROUND_ROCM_7_1_FP16_DROPOUT_MISCOMPILE 0 +#endif +#endif + // workaround for AMDGPU compiler VGPR aliasing bug in dropout codegen (ROCm >= 7.12) // Philox RNG VGPR parameters get aliased under high register pressure (d256 tile). // fp16 is affected; bf16 is not (different type conversion codegen path). diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 16f5b00bb11..fc846690e28 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -67,7 +67,7 @@ struct FmhaFwdKernel static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum; static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ; - static constexpr bool kHasSink = FmhaPipeline::kHasSink; + static constexpr bool kHasGptOssSink = FmhaPipeline::kHasGptOssSink; using AttentionVariant = ck_tile::remove_cvref_t; using FmhaMask = ck_tile::remove_cvref_t; @@ -1439,10 +1439,14 @@ struct FmhaFwdKernel long_index_t batch_offset_q_descale = 0; long_index_t batch_offset_k_descale = 0; long_index_t batch_offset_v_descale = 0; - const float sink_value = - kargs.sink_ptr != nullptr - ? (*(static_cast(kargs.sink_ptr) + i_nhead)) / kargs.scale_s - : -numeric::infinity(); + // GPT-OSS sink value: only computed and passed when kHasGptOssSink=true. + // When kHasGptOssSink=false, pipelines do not use sink_value at all. + const float sink_value = [&]() { + if constexpr(kHasGptOssSink) + return (*(static_cast(kargs.sink_ptr) + i_nhead)) / kargs.scale_s; + else + return -numeric::infinity(); + }(); if constexpr(kIsGroupMode) { @@ -2184,10 +2188,12 @@ struct FmhaFwdKernel constexpr bool PrefillCase = FmhaPipeline::kM0 > 64; // divide problem const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); - const float sink_value = - kargs.sink_ptr != nullptr - ? (*(static_cast(kargs.sink_ptr) + i_nhead)) / kargs.scale_s - : -numeric::infinity(); + const float sink_value = [&]() { + if constexpr(kHasGptOssSink) + return (*(static_cast(kargs.sink_ptr) + i_nhead)) / kargs.scale_s; + else + return -numeric::infinity(); + }(); const index_t i_m0 = i_tile_m * FmhaPipeline::kM0; const index_t i_n1 = i_tile_n * FmhaPipeline::kN1; diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp index 89bd22c4715..31695ac4db1 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp @@ -7,6 +7,7 @@ #include "ck_tile/ops/common.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/variants.hpp" +#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include #include @@ -55,7 +56,8 @@ struct FmhaFwdPagedKVKernel static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ; static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV; - static constexpr bool kHasSink = FmhaPipeline::kHasSink; + static constexpr bool kHasStreamSink = FmhaPipeline::Problem::kHasStreamSink; + static constexpr bool kHasGptOssSink = FmhaPipeline::Problem::kHasGptOssSink; using AttentionVariant = ck_tile::remove_cvref_t; using FmhaMask = ck_tile::remove_cvref_t; @@ -102,7 +104,8 @@ struct FmhaFwdPagedKVKernel (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) + (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + - (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ) + (kHasSink ? "_sink" : "_nsink" ); + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ) + + ck_tile::FmhaSinkNameSuffix(); #undef _SS_ #undef _TS_ // clang-format on @@ -917,10 +920,12 @@ struct FmhaFwdPagedKVKernel long_index_t batch_offset_lse = 0; long_index_t batch_offset_o = 0; index_t kv_l2p_offset = 0; - const float sink_value = - kargs.sink_ptr != nullptr - ? (*(static_cast(kargs.sink_ptr) + i_nhead)) / kargs.scale_s - : -numeric::infinity(); + const float sink_value = [&]() { + if constexpr(kHasGptOssSink) + return (*(static_cast(kargs.sink_ptr) + i_nhead)) / kargs.scale_s; + else + return -numeric::infinity(); + }(); if constexpr(kIsGroupMode) { diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index d6c4d70fee5..bc979a21576 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -7,6 +7,7 @@ #include "ck_tile/ops/common.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/variants.hpp" +#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include #include @@ -51,7 +52,8 @@ struct FmhaFwdSplitKVKernel static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV; - static constexpr bool kHasSink = FmhaPipeline::Problem::kHasSink; + static constexpr bool kHasStreamSink = FmhaPipeline::Problem::kHasStreamSink; + static constexpr bool kHasGptOssSink = FmhaPipeline::Problem::kHasGptOssSink; static constexpr bool kMergeNumHeadGroupsSeqLenQ = FmhaPipeline::Problem::kMergeNumHeadGroupsSeqLenQ; using AttentionVariant = ck_tile::remove_cvref_t; @@ -102,7 +104,8 @@ struct FmhaFwdSplitKVKernel "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) + (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + - (kDoFp8StaticQuant ? "_squant" : "_nsquant") + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ) + (kHasSink ? "_sink" : "_nsink" ); + (kDoFp8StaticQuant ? "_squant" : "_nsquant") + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ) + + ck_tile::FmhaSinkNameSuffix(); #undef _SS_ #undef _TS_ // clang-format on @@ -619,10 +622,12 @@ struct FmhaFwdSplitKVKernel long_index_t batch_offset_o_acc = 0; index_t kv_l2p_offset = 0; // logical-to-physical offset of seqlen_k coordinate. only used for paged-kvcache - const float sink_value = - kargs.sink_ptr != nullptr - ? (*(static_cast(kargs.sink_ptr) + i_nhead)) / kargs.scale_s - : -numeric::infinity(); + const float sink_value = [&]() { + if constexpr(kHasGptOssSink) + return (*(static_cast(kargs.sink_ptr) + i_nhead)) / kargs.scale_s; + else + return -numeric::infinity(); + }(); if constexpr(kIsGroupMode) { diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index a8b94b6e417..04d7b415f2a 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -291,6 +291,9 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync static constexpr bool kHasDropout = Problem::kHasDropout; static constexpr auto kKVMemoryLayout = Problem::kKVMemoryLayout; static constexpr auto QScaleEnum = Problem::QScaleEnum; + static constexpr auto kSinkMode = Problem::kSinkMode; + static constexpr bool kHasStreamSink = Problem::kHasStreamSink; + static constexpr bool kHasGptOssSink = Problem::kHasGptOssSink; // For KV_BLOCKSCALE: shift value for exp2(x + shift) to scale P to [0, 2^shift] // This avoids explicit P *= scale_p and v_descale /= scale_p operations @@ -526,8 +529,9 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync auto l = MLBlockTileType{}; clear_tile(o_acc); - if(__builtin_isinf_sign(sink_v) >= 0) + if constexpr(kHasGptOssSink) { + // sink_v is always valid when kHasGptOssSink=true; no isinf check needed. #if CK_TILE_FMHA_FWD_FAST_EXP2 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || BiasEnum == BlockAttentionBiasEnum::ALIBI) @@ -561,7 +565,9 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync { auto lse = make_static_distributed_tensor(m.get_tile_distribution()); - set_tile(lse, SMPLComputeDataType{sink_v * scale_s}); + // Already inside if constexpr(kHasGptOssSink); sink_v is always valid. + const SMPLComputeDataType sink_lse = sink_v * scale_s; + set_tile(lse, sink_lse); store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); } buffer_load_fence(0); // rocm-6.1, if whole tile is masked out, need to fence(0) diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp index 3f6b9bc44fc..9f281820b9e 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp @@ -58,7 +58,9 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kIsPagedKV = Problem::kIsPagedKV; - static constexpr bool kHasSink = Problem::kHasSink; + static constexpr auto kSinkMode = Problem::kSinkMode; + static constexpr bool kHasStreamSink = Problem::kHasStreamSink; + static constexpr bool kHasGptOssSink = Problem::kHasGptOssSink; static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || @@ -229,7 +231,7 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS auto l = MLBlockTileType{}; clear_tile(o_acc); - if(__builtin_isinf_sign(sink_v) >= 0) + if constexpr(kHasGptOssSink) { #if CK_TILE_FMHA_FWD_FAST_EXP2 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || @@ -249,7 +251,7 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS } const auto q_origin = q_dram_window.get_window_origin(); const auto tile_range_result = [&mask, &q_origin]() { - if constexpr(kHasSink) + if constexpr(kHasStreamSink) return mask.GetSinkTileRangeAlongX( q_origin.at(number<0>{}), number{}, number{}); else @@ -275,7 +277,8 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS { auto lse = make_static_distributed_tensor(m.get_tile_distribution()); - set_tile(lse, SMPLComputeDataType{sink_v * scale_s}); + const SMPLComputeDataType sink_lse = sink_v * scale_s; + set_tile(lse, sink_lse); store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); } @@ -314,7 +317,7 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); const index_t bias_n_offset = [&]() { - if constexpr(kHasSink) + if constexpr(kHasStreamSink) return kv_load_start; else return logical_seqlen_k_start - @@ -360,7 +363,7 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS } const bool is_sink_tile = ((num_sink_loop - 1) == i_total_loops); const auto k_move_offset = [&]() { - if constexpr(kHasSink) + if constexpr(kHasStreamSink) return is_sink_tile ? logical_seqlen_k_start - sink_seq_end + kN0 : kN0; else return kN0; @@ -530,7 +533,7 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS }); }; - if constexpr(kHasSink) + if constexpr(kHasStreamSink) { apply_mask([&](auto row, auto col) { return mask.IsOutOfSinkBound(row, col); diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp index 1af244751ba..0360867edf2 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp @@ -57,7 +57,8 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kIsPagedKV = Problem::kIsPagedKV; static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits; - static constexpr bool kHasSink = Problem::kHasSink; + static constexpr bool kHasStreamSink = Problem::kHasStreamSink; + static constexpr bool kHasGptOssSink = Problem::kHasGptOssSink; static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || @@ -255,10 +256,18 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS auto l = MLBlockTileType{}; clear_tile(o_acc); - if((__builtin_isinf_sign(sink_v) >= 0) && i_split == 0) + if constexpr(kHasGptOssSink) { - set_tile(m, SMPLComputeDataType{sink_v * C_LOG2E}); - set_tile(l, SMPLComputeDataType{1.0f}); + if(i_split == 0) + { + set_tile(m, SMPLComputeDataType{sink_v * C_LOG2E}); + set_tile(l, SMPLComputeDataType{1.0f}); + } + else + { + set_tile(m, -numeric::infinity()); + clear_tile(l); + } } else { @@ -268,7 +277,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS const auto q_origin = q_dram_window.get_window_origin(); const auto tile_range_result = [&mask, &q_origin, num_splits, i_split]() { - if constexpr(kHasSink) + if constexpr(kHasStreamSink) return mask.GetSinkTileRangeAlongX( q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split); else @@ -293,7 +302,8 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS { auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); - set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s}); + const SMPLComputeDataType sink_lse = sink_v * scale_s; + set_tile(lse_acc, sink_lse); if(get_thread_local_1d_id() < kM0) { store_tile(lse_acc_dram_window_tmp, @@ -310,10 +320,13 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS { auto [start, end] = mask.GetTileRangeAlongX( q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split - 1); - if((__builtin_isinf_sign(sink_v) >= 0) && start >= end) + if constexpr(kHasGptOssSink) { - set_tile(m, SMPLComputeDataType{sink_v}); - set_tile(l, SMPLComputeDataType{1.0f}); + if(start >= end) + { + set_tile(m, SMPLComputeDataType{sink_v}); + set_tile(l, SMPLComputeDataType{1.0f}); + } } } const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset; @@ -345,7 +358,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); const index_t bias_n_offset = [&]() { - if constexpr(kHasSink) + if constexpr(kHasStreamSink) return kv_load_start; else return logical_seqlen_k_start - @@ -409,7 +422,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS clear_tile(s_acc); // initialize C const bool is_sink_tile = ((num_sink_loop - 1) == i_total_loops); const auto k_move_offset = [&]() { - if constexpr(kHasSink) + if constexpr(kHasStreamSink) return is_sink_tile ? logical_seqlen_k_start - sink_seq_end + kN0 : kN0; else return kN0; @@ -585,7 +598,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS }); }; - if constexpr(kHasSink) + if constexpr(kHasStreamSink) { apply_mask( [&](auto row, auto col) { return mask.IsOutOfSinkBound(row, col); }); diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp index 842b48013a2..69c16d117bc 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -57,7 +57,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kIsPagedKV = Problem::kIsPagedKV; static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits; - static constexpr bool kHasSink = Problem::kHasSink; + static constexpr bool kHasStreamSink = Problem::kHasStreamSink; + static constexpr bool kHasGptOssSink = Problem::kHasGptOssSink; static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || @@ -229,18 +230,26 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS auto l = MLBlockTileType{}; clear_tile(o_acc); - if((__builtin_isinf_sign(sink_v) >= 0) && i_split == 0) + if constexpr(kHasGptOssSink) { + if(i_split == 0) + { #if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - BiasEnum == BlockAttentionBiasEnum::ALIBI) - set_tile(m, sink_v * C_LOG2E * scale_s); - else - set_tile(m, sink_v * C_LOG2E); + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + set_tile(m, sink_v * C_LOG2E * scale_s); + else + set_tile(m, sink_v * C_LOG2E); #else - set_tile(m, sink_v); + set_tile(m, sink_v); #endif - set_tile(l, SMPLComputeDataType{1.0f}); + set_tile(l, SMPLComputeDataType{1.0f}); + } + else + { + set_tile(m, -numeric::infinity()); + clear_tile(l); + } } else { @@ -250,7 +259,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS const auto q_origin = q_dram_window.get_window_origin(); const auto tile_range_result = [&mask, &q_origin, num_splits, i_split]() { - if constexpr(kHasSink) + if constexpr(kHasStreamSink) return mask.GetSinkTileRangeAlongX( q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split); else @@ -277,7 +286,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS { auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); - set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s}); + const SMPLComputeDataType sink_lse = sink_v * scale_s; + set_tile(lse_acc, sink_lse); store_tile(lse_acc_dram_window_tmp, tile_elementwise_in(lse_acc_element_func, lse_acc)); } @@ -292,18 +302,26 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS { auto [start, end] = mask.GetTileRangeAlongX( q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split - 1); - if((__builtin_isinf_sign(sink_v) >= 0) && start >= end) + if constexpr(kHasGptOssSink) { + if(start >= end) + { #if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - BiasEnum == BlockAttentionBiasEnum::ALIBI) - set_tile(m, sink_v * C_LOG2E * scale_s); - else - set_tile(m, sink_v * C_LOG2E); + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + set_tile(m, sink_v * C_LOG2E * scale_s); + else + set_tile(m, sink_v * C_LOG2E); #else - set_tile(m, sink_v); + set_tile(m, sink_v); #endif - set_tile(l, SMPLComputeDataType{1.0f}); + set_tile(l, SMPLComputeDataType{1.0f}); + } + else + { + set_tile(m, -numeric::infinity()); + clear_tile(l); + } } else { @@ -341,7 +359,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); const index_t bias_n_offset = [&]() { - if constexpr(kHasSink) + if constexpr(kHasStreamSink) return kv_load_start; else return logical_seqlen_k_start - @@ -388,7 +406,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS const bool is_sink_tile = ((num_sink_loop - 1) == i_total_loops); const auto k_move_offset = [&]() { - if constexpr(kHasSink) + if constexpr(kHasStreamSink) return is_sink_tile ? logical_seqlen_k_start - sink_seq_end + kN0 : kN0; else return kN0; @@ -562,7 +580,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS }); }; - if constexpr(kHasSink) + if constexpr(kHasStreamSink) { apply_mask( [&](auto row, auto col) { return mask.IsOutOfSinkBound(row, col); }); diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index 87db7b85b9e..e871b2c1289 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -72,7 +72,9 @@ struct BlockFmhaPipelineProblem static constexpr bool kHasDropout = Traits::kHasDropout; static constexpr auto QScaleEnum = Traits::QScaleEnum; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; - static constexpr bool kHasSink = Traits::kHasSink; + static constexpr auto kSinkMode = Traits::kSinkMode; + static constexpr bool kHasStreamSink = Traits::kHasStreamSink; + static constexpr bool kHasGptOssSink = Traits::kHasGptOssSink; }; template = 0) + if constexpr(kHasGptOssSink) { + // sink_v is always valid when kHasGptOssSink=true; no isinf check needed. #if CK_TILE_FMHA_FWD_FAST_EXP2 if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI || BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) @@ -313,7 +315,7 @@ struct BlockFmhaPipelineQRKSVS const auto q_origin = q_dram_window.get_window_origin(); const auto tile_range_result = [&mask, &q_origin]() { - if constexpr(kHasSink) + if constexpr(kHasStreamSink) return mask.GetSinkTileRangeAlongX( q_origin.at(number<0>{}), number{}, number{}); else @@ -342,9 +344,12 @@ struct BlockFmhaPipelineQRKSVS auto lse = make_static_distributed_tensor(m.get_tile_distribution()); - if(__builtin_isinf_sign(sink_v) >= 0) + if constexpr(kHasGptOssSink) { - set_tile(lse, SMPLComputeDataType{sink_v * scale_s}); + // Precompute sink_lse to avoid sink_v*scale_s in set_tile, which + // triggers a compiler register-allocation bug for dropout kernels. + const SMPLComputeDataType sink_lse = sink_v * scale_s; + set_tile(lse, sink_lse); } else { @@ -664,7 +669,7 @@ struct BlockFmhaPipelineQRKSVS #endif } } - if constexpr(kHasSink) + if constexpr(kHasStreamSink) { if(i_total_loops == 0) move_tile_window(bias_dram_window, {0, seqlen_k_start - sink_seq_end}); @@ -695,7 +700,7 @@ struct BlockFmhaPipelineQRKSVS }); }; - if constexpr(kHasSink) + if constexpr(kHasStreamSink) { apply_mask([&](auto&&... args) { return variant.LogitsSinkMask(std::forward(args)...); @@ -834,7 +839,7 @@ struct BlockFmhaPipelineQRKSVS auto randval_ptr = reinterpret_cast(smem_ptr); index_t seq_offset = [&]() { - if constexpr(!kHasSink) + if constexpr(!kHasStreamSink) return seqlen_k_start + i_total_loops * kN0; const bool in_sink_phase = (num_sink_loop > i_total_loops); @@ -980,7 +985,7 @@ struct BlockFmhaPipelineQRKSVS }); } // move K tile windows - if constexpr(kHasSink) + if constexpr(kHasStreamSink) { if(i_total_loops == 0) { diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 7b97d01fa4f..41d460f83b6 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -63,7 +63,8 @@ struct BlockFmhaPipelineQRKSVSAsync static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kHasDropout = Problem::kHasDropout; - static constexpr bool kHasSink = Problem::kHasSink; + static constexpr bool kHasStreamSink = Problem::kHasStreamSink; + static constexpr bool kHasGptOssSink = Problem::kHasGptOssSink; // For BLOCKSCALE: shift value for exp2(x + shift) to scale P to [0, 2^shift] static constexpr float OCP_FP8_SHIFT = 8.0f; @@ -292,7 +293,7 @@ struct BlockFmhaPipelineQRKSVSAsync auto l = MLBlockTileType{}; clear_tile(o_acc); - if(__builtin_isinf_sign(sink_v) >= 0) + if constexpr(kHasGptOssSink) { #if CK_TILE_FMHA_FWD_FAST_EXP2 if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI || @@ -313,7 +314,7 @@ struct BlockFmhaPipelineQRKSVSAsync __builtin_amdgcn_sched_barrier(0); const auto q_origin = q_dram_window.get_window_origin(); const auto tile_range_result = [&mask, &q_origin]() { - if constexpr(kHasSink) + if constexpr(kHasStreamSink) return mask.GetSinkTileRangeAlongX( q_origin.at(number<0>{}), number{}, number{}); else @@ -343,7 +344,8 @@ struct BlockFmhaPipelineQRKSVSAsync { auto lse = make_static_distributed_tensor(m.get_tile_distribution()); - set_tile(lse, SMPLComputeDataType{sink_v * scale_s}); + const SMPLComputeDataType sink_lse = sink_v * scale_s; + set_tile(lse, sink_lse); store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); } @@ -534,7 +536,7 @@ struct BlockFmhaPipelineQRKSVSAsync #endif } } - if constexpr(kHasSink) + if constexpr(kHasStreamSink) { if(i_total_loops == 0) move_tile_window(bias_dram_window, {0, seqlen_k_start - sink_seq_end}); @@ -566,7 +568,7 @@ struct BlockFmhaPipelineQRKSVSAsync }); }; - if constexpr(kHasSink) + if constexpr(kHasStreamSink) { apply_mask([&](auto&&... args) { return variant.LogitsSinkMask(std::forward(args)...); @@ -747,7 +749,7 @@ struct BlockFmhaPipelineQRKSVSAsync reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); index_t seq_offset = [&]() { - if constexpr(!kHasSink) + if constexpr(!kHasStreamSink) return seqlen_k_start + i_total_loops * kN0; const bool in_sink_phase = (num_sink_loop > i_total_loops); @@ -845,7 +847,7 @@ struct BlockFmhaPipelineQRKSVSAsync i_total_loops++; if(i_total_loops < num_total_loop) { - if constexpr(kHasSink) + if constexpr(kHasStreamSink) { if(i_total_loops == 0) { diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp index c0d5ca291f0..04fbda1bdc0 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp @@ -69,7 +69,8 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kHasUnevenSplits = true; - static constexpr bool kHasSink = Problem::kHasSink; + static constexpr bool kHasStreamSink = Problem::kHasStreamSink; + static constexpr bool kHasGptOssSink = Problem::kHasGptOssSink; static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || @@ -194,7 +195,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload auto l = MLBlockTileType{}; clear_tile(o_acc); - if(__builtin_isinf_sign(sink_v) >= 0) + if constexpr(kHasGptOssSink) { #if CK_TILE_FMHA_FWD_FAST_EXP2 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || @@ -228,7 +229,8 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload { auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); - set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s}); + const SMPLComputeDataType sink_lse = sink_v * scale_s; + set_tile(lse_acc, sink_lse); store_tile(lse_acc_dram_window_tmp, lse_acc); } @@ -714,7 +716,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload auto l = MLBlockTileType{}; clear_tile(o_acc); - if(__builtin_isinf_sign(sink_v) >= 0) + if constexpr(kHasGptOssSink) { #if CK_TILE_FMHA_FWD_FAST_EXP2 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || @@ -748,7 +750,8 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload { auto lse_acc = make_static_distributed_tensor(m.get_tile_distribution()); - set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s}); + const SMPLComputeDataType sink_lse = sink_v * scale_s; + set_tile(lse_acc, sink_lse); store_tile(lse_acc_dram_window_tmp, lse_acc); } diff --git a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index 0670985e4f3..156c4eb53e0 100644 --- a/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/projects/composablekernel/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -11,6 +11,47 @@ namespace ck_tile { +// Compile-time sink mode for FMHA forward kernels. +// Two independent sink mechanisms that can be combined: +// kStreamLLM: Sliding-window attention with initial-token sink (ICLR 2024, MIT HAN Lab) +// Controls KV tile loading schedule and per-pixel mask logic. +// kGptOss: GPT-OSS learnable softmax bias (OpenAI); initializes softmax m/l. +enum class FmhaSinkMode : int +{ + kNone = 0, // No sink — zero overhead + kStreamLLM = 1, // StreamingLLM sliding-window sink + kGptOss = 2, // GPT-OSS learnable softmax bias sink + kBoth = 3, // Both sinks simultaneously +}; + +// Helper that derives kHasStreamSink / kHasGptOssSink from a FmhaSinkMode value. +// Use as a base or member in traits structs to avoid repeating the same two-line +// derivation in every traits type. +template +struct FmhaSinkModeHelper +{ + static constexpr bool kHasStreamSink = + (kSinkMode_ == FmhaSinkMode::kStreamLLM || kSinkMode_ == FmhaSinkMode::kBoth); + static constexpr bool kHasGptOssSink = + (kSinkMode_ == FmhaSinkMode::kGptOss || kSinkMode_ == FmhaSinkMode::kBoth); +}; + +// Returns the kernel name suffix for a given sink mode combination. +// Mirrors the Python-side SINK_NAME_MAP: none->"_nsink", stream->"_ssink", +// gptoss->"_gsink", both->"_bsink". +template +constexpr const char* FmhaSinkNameSuffix() +{ + if constexpr(!HasStream && !HasGptOss) + return "_nsink"; + else if constexpr(HasStream && !HasGptOss) + return "_ssink"; + else if constexpr(!HasStream && HasGptOss) + return "_gsink"; + else + return "_bsink"; +} + template + index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */ + bool kSkipMinSeqlenQ_ = false, /* skip min seqlen q while chunked prefill */ + FmhaSinkMode kSinkMode_ = FmhaSinkMode::kNone> struct TileFmhaTraits { static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; @@ -38,7 +79,9 @@ struct TileFmhaTraits static constexpr auto QScaleEnum = QScaleEnum_; static constexpr index_t kBlockPerCu = kBlockPerCu_; static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; - static constexpr bool kHasSink = kHasSink_; + static constexpr FmhaSinkMode kSinkMode = kSinkMode_; + static constexpr bool kHasStreamSink = FmhaSinkModeHelper::kHasStreamSink; + static constexpr bool kHasGptOssSink = FmhaSinkModeHelper::kHasGptOssSink; }; template + FmhaSinkMode::kNone> { static constexpr auto kKVMemoryLayout = kKVMemoryLayout_; static constexpr auto kKVLookupTable = kKVLookupTable_; @@ -110,9 +153,9 @@ template 1 or fwd training is running */ bool kIsPagedKV_, bool kDoFp8StaticQuant_, - index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */ - bool kSkipMinSeqlenQ_ = false, /* skip min seqlen q while chunked prefill */ - bool kHasSink_ = false> + index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */ + bool kSkipMinSeqlenQ_ = false, /* skip min seqlen q while chunked prefill */ + FmhaSinkMode kSinkMode_ = FmhaSinkMode::kNone> struct TileFmhaFwdPagedKVTraits { static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; @@ -127,7 +170,9 @@ struct TileFmhaFwdPagedKVTraits static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; static constexpr index_t kBlockPerCu = kBlockPerCu_; static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; - static constexpr bool kHasSink = kHasSink_; + static constexpr FmhaSinkMode kSinkMode = kSinkMode_; + static constexpr bool kHasStreamSink = FmhaSinkModeHelper::kHasStreamSink; + static constexpr bool kHasGptOssSink = FmhaSinkModeHelper::kHasGptOssSink; }; template + FmhaSinkMode kSinkMode_ = FmhaSinkMode::kNone> struct TileFmhaFwdSplitKVTraits { static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; @@ -160,7 +205,9 @@ struct TileFmhaFwdSplitKVTraits static constexpr bool kHasUnevenSplits = kHasUnevenSplits_; static constexpr bool kMergeNumHeadGroupsSeqLenQ = kMergeNumHeadGroupsSeqLenQ_; static constexpr index_t kBlockPerCu = kBlockPerCu_; - static constexpr bool kHasSink = kHasSink_; + static constexpr FmhaSinkMode kSinkMode = kSinkMode_; + static constexpr bool kHasStreamSink = FmhaSinkModeHelper::kHasStreamSink; + static constexpr bool kHasGptOssSink = FmhaSinkModeHelper::kHasGptOssSink; }; template ) + { + if(hdim_q == 256 && hdim_v == 24 && mode == mode_enum::batch && p_drop > 0) + GTEST_SKIP() << "Skipped: fp16 dropout d256/d24 batch — compiler VGPR aliasing " + "bug (ROCm 7.1.x)"; + } +#endif + auto result = fmha_fwd_run(mode, batch, nhead, @@ -601,6 +610,14 @@ TEST_P(Dropout, DataTypeConfig) auto [drop_seed, drop_offset, drop_prefs] = drop_seed_offset_prefs; auto [batch, nhead, nhead_k, seqlen_q, seqlen_k, mask_str] = dims_mask; +#if CK_TILE_WORKAROUND_ROCM_7_1_FP16_DROPOUT_MISCOMPILE + if constexpr(std::is_same_v) + { + if(hdim_q > 128 && mode == mode_enum::batch) + GTEST_SKIP() << "Skipped: fp16 dropout d256 batch — compiler VGPR aliasing " + "bug (ROCm 7.1.x)"; + } +#endif #if CK_TILE_WORKAROUND_ROCM_7_12_FP16_DROPOUT_MISCOMPILE if constexpr(std::is_same_v) {