From c8c360abac3684b9cd78caab429af1bd9f360a47 Mon Sep 17 00:00:00 2001 From: rocking Date: Mon, 24 Nov 2025 19:09:17 +0800 Subject: [PATCH 1/2] Fix batch prefill aiter compile fail --- .../01_fmha/codegen/ops/fmha_batch_prefill.py | 49 ++++++++++--------- 1 file changed, 26 insertions(+), 23 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index 74db4e084c9..a98576b0536 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -20,6 +20,8 @@ FWD_DTYPE_MAP, BOOL_MAP, PIPELINE_ENUM_MAP, + QSCALE_CHECK_MAP, + QSCALE_MAP, ) from codegen.utils import update_file @@ -60,7 +62,7 @@ false, {F_lse}, {F_dropout}, - {F_squant}, + {F_qscale}, {F_occupancy}>; using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; @@ -98,7 +100,7 @@ ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel; using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>; + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>; #include @@ -175,9 +177,9 @@ }} """ -FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && +FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ - using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>; + using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>; return fmha_batch_prefill_(s, a); }} """ @@ -216,7 +218,7 @@ class FmhaFwdApiTrait: bias: str # lse: str # dropout: str - squant: str # + qscale: str # spad: str skpad: str dpad: str @@ -227,7 +229,7 @@ class FmhaFwdApiTrait: def name(self) -> str: return ( f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" - + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}" + + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}" ) @property @@ -312,7 +314,7 @@ class FmhaFwdPipeline: F_bias: str # true/false F_lse: str # F_dropout: str # - F_squant: str # + F_qscale: str # no/pertensor F_mask: str # value from MASK_MAP F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint()) @@ -370,10 +372,10 @@ def pad_name() -> str: else: n += "_ndropout" - if self.F_squant == "t": - n += "_squant" + if self.F_qscale != "no": + n += f"_{self.F_qscale}" else: - n += "_nsquant" + n += "_nqscale" return n @@ -413,7 +415,8 @@ def api(self) -> str: F_bias=BIAS_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], - F_squant=BOOL_MAP[trait.squant], + F_qscale_check=QSCALE_CHECK_MAP[trait.qscale], + F_qscale=QSCALE_MAP[trait.qscale], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, @@ -522,7 +525,7 @@ def template(self) -> str: F_bias=BIAS_MAP[self.F_pipeline.F_bias], F_lse=BOOL_MAP[self.F_pipeline.F_lse], F_dropout=BOOL_MAP[self.F_pipeline.F_dropout], - F_squant=BOOL_MAP[self.F_pipeline.F_squant], + F_qscale=QSCALE_MAP[self.F_pipeline.F_qscale], F_occupancy=self.F_tile.F_occupancy, F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], @@ -562,7 +565,7 @@ def api_trait(self) -> FmhaFwdApiTrait: bias=self.F_pipeline.F_bias, lse=self.F_pipeline.F_lse, dropout=self.F_pipeline.F_dropout, - squant=self.F_pipeline.F_squant, + qscale=self.F_pipeline.F_qscale, spad=self.F_pipeline.F_spad, skpad=self.F_pipeline.F_skpad, dpad=self.F_pipeline.F_dpad, @@ -587,7 +590,7 @@ def get_pipelines(dtype, hdim, receipt, mask_impl) -> List[FmhaFwdPipeline]: # TODO: the order of List matters! the later in this list will be also be checked later # TODO: currently for qr pipeline, let 't' padding to appear later!! # TODO: how to design this more generic? - squant = "t" if dtype == "fp8" else "f" + qscale = "no" pipelines = [] if dtype in ["fp16", "bf16"]: for logits, mask, bias, lse, dropout in itertools.product( @@ -597,10 +600,10 @@ def get_pipelines(dtype, hdim, receipt, mask_impl) -> List[FmhaFwdPipeline]: ["t", "f"], ["t", "f"], ): - pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, squant, mask)) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask)) # fmt: skip - # pipelines.append(FmhaFwdPipeline("qr_async", "col", "t", "f", "t", "t", logits, bias, lse, dropout, squant, mask)) # fmt: skip - # pipelines.append(FmhaFwdPipeline("qr_async", "col", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip + # pipelines.append(FmhaFwdPipeline("qr_async", "col", "t", "f", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip + # pipelines.append(FmhaFwdPipeline("qr_async", "col", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask)) # fmt: skip else: assert False return pipelines @@ -672,7 +675,7 @@ def get_fwd_blobs( cond = dtype in ["fp16", "bf16"] cond &= pipeline.F_vlayout == "row" cond &= pipeline.F_bias in ["no", "alibi"] - cond &= pipeline.F_squant == "f" + cond &= pipeline.F_qscale == "no" if not cond: continue # PyTorch integration @@ -680,7 +683,7 @@ def get_fwd_blobs( cond = dtype in ["fp16", "bf16"] cond &= pipeline.F_vlayout == "row" cond &= pipeline.F_bias in ["no", "bias"] - cond &= pipeline.F_squant == "f" + cond &= pipeline.F_qscale == "no" if not cond: continue # Aiter(mha_fwd) integration @@ -688,7 +691,7 @@ def get_fwd_blobs( cond = dtype in ["fp16", "bf16"] cond &= mode == "batch" cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_squant == "f" + cond &= pipeline.F_qscale == "no" if not cond: continue # Aiter(mha_batch_prefill) integration @@ -696,7 +699,7 @@ def get_fwd_blobs( cond = dtype in ["fp16", "bf16"] cond &= mode == "group" cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_squant == "f" + cond &= pipeline.F_qscale == "no" if not cond: continue # aiter::mha_batch_prefill C++ api integration @@ -704,7 +707,7 @@ def get_fwd_blobs( cond = dtype in ["fp16", "bf16"] cond &= mode == "group" cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_squant == "f" + cond &= pipeline.F_qscale == "no" if not cond: continue From a9f209b75fa557b9caf3b932e4c2bba938d3f54a Mon Sep 17 00:00:00 2001 From: rocking Date: Mon, 24 Nov 2025 09:45:32 -0600 Subject: [PATCH 2/2] Fix compile error --- .../fmha/kernel/fmha_batch_prefill_kernel.hpp | 108 +++++------------- 1 file changed, 29 insertions(+), 79 deletions(-) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp index 3b476299e15..c6fbd6945f8 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp" #include "ck_tile/ops/fmha/block/variants.hpp" #include @@ -53,7 +54,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; - static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; + static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum; using AttentionVariant = ck_tile::remove_cvref_t; using FmhaMask = ck_tile::remove_cvref_t; static constexpr bool kHasMask = FmhaMask::IsMasking; @@ -99,7 +100,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) + (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + - (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ); + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + + (QScaleEnum == BlockAttentionQuantScaleEnum::NO_SCALE ? _SS_("_nqscale") : (_SS_("_") + BlockAttentionQuantScaleEnumToStr::name)); #undef _SS_ #undef _TS_ // clang-format on @@ -202,12 +204,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel ck_tile::GenericAttentionMaskEnum mask_type; }; - struct FmhaFwdFp8StaticQuantKargs - { - float scale_p; - float scale_o; - }; - struct FmhaFwdCommonLSEKargs { void* lse_ptr = nullptr; @@ -278,9 +274,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel FmhaFwdEmptyKargs<0>>>, std::conditional_t>, std::conditional_t>, - std::conditional_t>, - std::conditional_t>, - std::conditional_t> + std::conditional_t>, + std::conditional_t> { ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; @@ -297,9 +292,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel FmhaFwdEmptyKargs<0>>>, std::conditional_t>, std::conditional_t>, - std::conditional_t>, - std::conditional_t>, - std::conditional_t> + std::conditional_t>, + std::conditional_t> { const int32_t* seqstart_q_ptr; ck_tile::index_t batch_stride_k; @@ -337,8 +331,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel ck_tile::index_t page_block_size, #endif float scale_s, - float scale_p, - float scale_o, + [[maybe_unused]] float scale_p, + [[maybe_unused]] float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, @@ -401,7 +395,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel {}, // placeholder for bias {}, // placeholder for mask {}, // placeholder for lse - {}, // placeholder for fp8_static_quant args {}, // placeholder for dropout {}, // placeholder for logits_soft_cap batch_stride_q, @@ -433,11 +426,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel kargs.nhead_stride_lse = nhead_stride_lse; kargs.batch_stride_lse = batch_stride_lse; } - if constexpr(kDoFp8StaticQuant) - { - kargs.scale_p = scale_p; - kargs.scale_o = scale_o; - } if constexpr(kHasDropout) { if(drop_seed_offset.index() == 0) // seed & offset come from host @@ -489,8 +477,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel ck_tile::index_t page_block_size, #endif float scale_s, - float scale_p, - float scale_o, + [[maybe_unused]] float scale_p, + [[maybe_unused]] float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, @@ -548,7 +536,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel {}, // placeholder for bias {}, // placeholder for mask {}, // placeholder for lse - {}, // placeholder for fp8_static_quant args {}, // placeholder for dropout {}, // placeholder for logits_soft_cap reinterpret_cast(seqstart_q_ptr), @@ -577,11 +564,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel kargs.lse_ptr = lse_ptr; kargs.nhead_stride_lse = nhead_stride_lse; } - if constexpr(kDoFp8StaticQuant) - { - kargs.scale_p = scale_p; - kargs.scale_o = scale_o; - } if constexpr(kHasDropout) { if(drop_seed_offset.index() == 0) // seed & offset come from host @@ -1082,55 +1064,23 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}; auto o_acc_tile = [&]() { - if constexpr(kDoFp8StaticQuant) - { - return FmhaPipeline{}( - q_dram_window, - identity{}, // q_element_func - k_dram_window, - identity{}, // k_element_func - v_dram_window, - identity{}, // v_element_func - bias_dram_window, - identity{}, // bias_element_func - randval_dram_window, - lse_dram_window, - identity{}, // lse_element_func - identity{}, // s_acc_element_func - scales{kargs.scale_p}, // p_compute_element_func - composes(saturates{}, scales{kargs.scale_o}), // o_acc_element_func - mask, - position_encoding, - kargs.scale_s, - variant, - variant_params, - block_indices, - smem_ptr, - kargs.kv_page_indices, - kargs.stride_k, - kargs.stride_v, - dropout); - } - else - { - return FmhaPipeline{}(q_dram_window, - k_dram_window, - v_dram_window, - bias_dram_window, - randval_dram_window, - lse_dram_window, - mask, - position_encoding, - kargs.scale_s, - variant, - variant_params, - block_indices, - smem_ptr, - kargs.kv_page_indices, - kargs.stride_k, - kargs.stride_v, - dropout); - } + return FmhaPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + bias_dram_window, + randval_dram_window, + lse_dram_window, + mask, + position_encoding, + kargs.scale_s, + variant, + variant_params, + block_indices, + smem_ptr, + kargs.kv_page_indices, + kargs.stride_k, + kargs.stride_v, + dropout); }(); // O DRAM and O DRAM window