diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index e35c7f9c369..ce914b92afb 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -62,7 +62,7 @@ set(FMHA_BWD_CODE_GEN_COMMON_ARGS # there is no corresponding instance for parameters). if(BUILD_TESTING) # Filters are in the order of FMHA_FWD_KNOWN_APIS: fwd,fwd_splitkv_combine@fwd_splitkv,fwd_appendkv,pagedkv_prefill - list(APPEND FMHA_FWD_CODE_GEN_COMMON_ARGS --filter *_nlogits*_nskip*_nsink*,*@*_nlogits*_nbias*_nsink*,*,*_nlogits*_nskip*_pagedkv*) + list(APPEND FMHA_FWD_CODE_GEN_COMMON_ARGS --filter *_nlogits*_nskip*,*@*_nlogits*_nbias*,*,*_nlogits*_nskip*_pagedkv) endif() # generate a list of kernels, but not actually emit files at config sta diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 6ef77a7c453..2acc4674108 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -66,8 +66,7 @@ {F_dropout}, {F_squant}, {F_occupancy}, - {F_skip}, - {F_sink}>; + {F_skip}>; using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; @@ -104,7 +103,7 @@ ck_tile::FmhaFwdKernel; using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip},{F_sink}>; + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; template<> float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) @@ -191,9 +190,9 @@ }} """ -FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&(t.has_sink == {F_sink}) && +FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ - using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}, {F_sink}>; + using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; return fmha_fwd_(s, a); }} """ @@ -240,14 +239,13 @@ class FmhaFwdApiTrait: dvpad: str skip: str tr_load: str - sink: str constraint: CppConstraint @property def name(self) -> str: return ( f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" - + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}-{self.sink}" + + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}" ) @property @@ -347,7 +345,6 @@ class FmhaFwdPipeline: F_mask: str # value from MASK_MAP F_skip: str # true/false F_trload: str # true/false - F_sink: str # true/false F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint()) @property @@ -418,10 +415,6 @@ def pad_name() -> str: n += "_trload" else: n += "_ntrload" - if self.F_sink == "t": - n += "_sink" - else: - n += "_nsink" return n @@ -469,7 +462,6 @@ def api(self) -> str: F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip], F_trload=BOOL_MAP[trait.tr_load], - F_sink=BOOL_MAP[trait.sink], F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_seqtune=trait.seqtune(max_bm0), @@ -596,7 +588,6 @@ def template(self) -> str: F_mode=MODE_MAP[self.F_mode], F_pipeline=PIPELINE_MAP[self.F_pipeline.tag], F_trload=BOOL_MAP[self.F_pipeline.F_trload], - F_sink=BOOL_MAP[self.F_pipeline.F_sink], ) @property @@ -639,7 +630,6 @@ def api_trait(self) -> FmhaFwdApiTrait: dvpad=self.F_pipeline.F_dvpad, skip=self.F_pipeline.F_skip, tr_load=self.F_pipeline.F_trload, - sink=self.F_pipeline.F_sink, constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint, ) @@ -706,51 +696,49 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli pipelines = [] if dtype in ["fp32"]: squant = "f" - for logits, mask, bias, lse, dropout, skip, sink in itertools.product( + for logits, mask, bias, lse, dropout, skip in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], - ["t", "f"], ): - pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "f", "t", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "t", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip elif dtype in ["fp16", "bf16"]: squant = "f" - for logits, mask, bias, lse, dropout, skip, sink in itertools.product( + for logits, mask, bias, lse, dropout, skip in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], - ["t", "f"], ): if hdim == 256 and hdim_v == 256: - pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip # the below two is used for hdim vectorize load - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip else: if bias == "bias": # TODO: rocm 6.2 compiler problem if using qr_async for bias case - pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip else: - pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip if receipt == 1 and bias != "bias": - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip # TODO: cover arbitraty hdim# fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip # TODO: cover arbitraty hdim# fmt: skip elif dtype in ["fp8", "fp8bf16", "fp8fp32"]: # no need lse/dropout kernels for logits, squant, mask, bias in itertools.product( ["f"], ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys() ): - pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f", "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f", "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip elif dtype in ["fp8fp16", "bf8"]: # TODO None @@ -769,14 +757,13 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli ) if dtype in ["fp16", "bf16"]: squant = "f" - for logits, mask, bias, lse, dropout, skip, sink in itertools.product( + for logits, mask, bias, lse, dropout, skip in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], - ["t", "f"], ): if ( (hdim, hdim_v) in [(64, 64), (128, 128)] @@ -785,8 +772,8 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli and dropout == "f" and skip == "f" ): - pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "t", sink)) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "t", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "t")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "t")) # fmt: skip return pipelines @@ -824,24 +811,23 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli pipelines = [] if dtype in ["fp16", "bf16"]: squant = "f" - for logits, mask, bias, lse, dropout, skip, sink in itertools.product( + for logits, mask, bias, lse, dropout, skip in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], - ["t", "f"], ): - pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip elif dtype in ["fp8", "fp8bf16", "fp8fp32"]: # no need lse/dropout kernels for logits, squant, mask, bias in itertools.product( ["f"], ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys() ): - pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f", "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f", "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip else: assert False return pipelines @@ -948,7 +934,6 @@ def get_fwd_blobs( cond &= pipeline.F_bias in ["no", "alibi"] cond &= pipeline.F_squant == "f" cond &= pipeline.F_skip == "f" - cond &= pipeline.F_sink == "f" if not cond: continue # PyTorch integration @@ -960,7 +945,6 @@ def get_fwd_blobs( cond &= mode == "batch" cond &= pipeline.F_skip == "f" cond &= pipeline.F_logits == "f" - cond &= pipeline.F_sink == "f" if not cond: continue # Aiter(mha_fwd) integration @@ -1001,7 +985,6 @@ def get_fwd_blobs( cond = dtype == "fp32" cond &= pipeline.F_skip == "f" cond &= pipeline.F_logits == "f" - cond &= pipeline.F_sink == "f" if not cond: continue # fp32 only, minimal set of parameters @@ -1015,7 +998,6 @@ def get_fwd_blobs( cond &= pipeline.F_skip == "f" cond &= pipeline.F_logits == "f" cond &= pipeline.F_mask == "s_no" - cond &= pipeline.F_sink == "f" if not cond: continue else: diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 5029f1fa97c..85c25561eab 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -74,8 +74,7 @@ {F_pagedkv}, kHasUnevenSplits, kMergeNumHeadGroupsSeqLenQ, - {F_occupancy}, - {F_sink}>; + {F_occupancy}>; using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem< typename FmhaFwdTypeConfig::QDataType, @@ -119,7 +118,7 @@ }} // anonymous namespace using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_sink}, {F_spad}, {F_skpad}, {F_dpad}, + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; #pragma clang diagnostic push @@ -281,8 +280,8 @@ """ FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) && - ((a.block_table_ptr != nullptr) == {F_pagedkv}) && (t.has_sink == {F_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv},{F_sink}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + ((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ + using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; // get combine kernel tile sizes using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType; @@ -334,7 +333,6 @@ class FmhaFwdSplitKVApiTrait: dpad: str dvpad: str pagedkv: str - sink: str # sink or not bn1comb: int # tile size along v head_dim of combine kernel @property @@ -342,7 +340,7 @@ def name(self) -> str: return ( f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-" - + f"{self.dvpad}-{self.pagedkv}-{self.sink}" + + f"{self.dvpad}-{self.pagedkv}" ) @property @@ -428,7 +426,6 @@ class FmhaFwdSplitKVPipeline: F_lse: str # F_squant: str # F_pagedkv: str # t/f - F_sink: str # t/f F_mask: str # value from MASK_MAP @property @@ -489,10 +486,6 @@ def pad_name() -> str: n += "_pagedkv" else: n += "_npagedkv" - if self.F_sink == "t": - n += "_sink" - else: - n += "_nsink" return n @@ -575,7 +568,6 @@ def api(self) -> str: F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv], - F_sink=BOOL_MAP[trait.sink], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, @@ -676,7 +668,6 @@ def template(self) -> str: F_squant=BOOL_MAP[self.F_pipeline.F_squant], F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv], F_occupancy=self.F_tile.F_occupancy, - F_sink=BOOL_MAP[self.F_pipeline.F_sink], F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], F_mode=MODE_MAP[self.F_mode], @@ -750,23 +741,19 @@ def get_pipelines(dtype, hdim, mask_impl) -> List[FmhaFwdSplitKVPipeline]: squant = "t" if dtype == "fp8" else "f" pipelines = [] if dtype in ["fp16", "bf16"]: - for logits, mask, bias, pagedkv, sink in itertools.product( - ["t", "f"], - get_mask_map(mask_impl).keys(), - BIAS_MAP.keys(), - ["t", "f"], - ["t", "f"], + for logits, mask, bias, pagedkv in itertools.product( + ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"] ): - pipelines.append(Pipeline("qr", "row", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip - pipelines.append(Pipeline("qr", "row", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip - pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip - pipelines.append(Pipeline("qr", "row", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip + pipelines.append(Pipeline("qr", "row", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip + pipelines.append(Pipeline("qr", "row", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip + pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip + pipelines.append(Pipeline("qr", "row", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip elif dtype in ["fp8", "bf8"]: for logits, mask, bias in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys() ): - pipelines.append(Pipeline("qr", "row", "f", "f", "f", "f", logits, bias, "t", squant, "f", "f", mask)) # fmt: skip - pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, "f", "f", mask)) # fmt: skip + pipelines.append(Pipeline("qr", "row", "f", "f", "f", "f", logits, bias, "t", squant, "f", mask)) # fmt: skip + pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, "f", mask)) # fmt: skip elif dtype in ["fp8fp16", "fp8bf16"]: # TODO None @@ -922,7 +909,6 @@ def get_fwd_splitkv_blobs( cond &= pipeline.F_vlayout == "row" cond &= pipeline.F_bias in ["no", "alibi"] cond &= pipeline.F_squant == "f" - cond &= pipeline.F_sink == "f" if not cond: continue # PyTorch integration @@ -932,7 +918,6 @@ def get_fwd_splitkv_blobs( cond &= pipeline.F_bias in ["no", "bias"] cond &= pipeline.F_squant == "f" cond &= mode == "batch" - cond &= pipeline.F_sink == "f" if not cond: continue # Aiter(mha_varlen_fwd) integration @@ -1091,7 +1076,6 @@ def write_blobs( lse=kernel.F_pipeline.F_lse, squant=kernel.F_pipeline.F_squant, pagedkv=kernel.F_pipeline.F_pagedkv, - sink=kernel.F_pipeline.F_sink, spad=kernel.F_pipeline.F_spad, skpad=kernel.F_pipeline.F_skpad, dpad=kernel.F_pipeline.F_dpad, diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py index 3ff47b940a3..17ac129e641 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py @@ -66,8 +66,7 @@ {F_pagedkv}, //pagedkv {F_squant}, {F_occupancy}, - {F_skip}, - {F_sink}>; + {F_skip}>; using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; @@ -102,7 +101,7 @@ ck_tile::FmhaFwdPagedKVKernel; using trait_{F_idx} = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}, {F_sink}>; + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>; template<> float fmha_fwd_pagedkv_(const ck_tile::stream_config& s, fmha_fwd_pagedkv_args a) @@ -131,9 +130,9 @@ }} """ -FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && (t.has_sink == {F_sink}) && +FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip},{F_sink}>; + using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>; return fmha_fwd_pagedkv_(s, a); }} """ @@ -165,13 +164,12 @@ class FmhaFwdApiTrait: dpad: str dvpad: str skip: str - sink: str @property def name(self) -> str: return ( f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" - + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.pagedkv}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}-{self.sink}" + + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.pagedkv}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}" ) @property @@ -259,7 +257,6 @@ class FmhaFwdPipeline: F_squant: str # F_mask: str # value from MASK_MAP F_skip: str # true/false - F_sink: str # true/false @property def name(self) -> str: @@ -324,10 +321,6 @@ def pad_name() -> str: n += "_pagedkv" else: n += "_npagedkv" - if self.F_sink == "t": - n += "_sink" - else: - n += "_nsink" return n @@ -371,7 +364,6 @@ def api(self) -> str: F_lse=BOOL_MAP[trait.lse], F_pagedkv=BOOL_MAP[trait.pagedkv], F_skip=BOOL_MAP[trait.skip], - F_sink=BOOL_MAP[trait.sink], F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, @@ -489,7 +481,6 @@ def template(self) -> str: F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv], F_squant=BOOL_MAP[self.F_pipeline.F_squant], F_skip=BOOL_MAP[self.F_pipeline.F_skip], - F_sink=BOOL_MAP[self.F_pipeline.F_sink], F_occupancy=self.F_tile.F_occupancy, F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], @@ -536,7 +527,6 @@ def api_trait(self) -> FmhaFwdApiTrait: dpad=self.F_pipeline.F_dpad, dvpad=self.F_pipeline.F_dvpad, skip=self.F_pipeline.F_skip, - sink=self.F_pipeline.F_sink, ) @@ -550,23 +540,22 @@ def get_pipelines(dtype, hdim, mask_impl) -> List[FmhaFwdPipeline]: squant = "t" if dtype == "fp8" else "f" pipelines = [] if dtype in ["fp16", "bf16"]: - for logits, mask, bias, pagedkv, skip, sink in itertools.product( + for logits, mask, bias, pagedkv, skip in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t"], ["f"], - ["t", "f"], ): - pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "f", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip, sink)) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip, sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "f", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip)) # fmt: skip elif dtype in ["fp8", "bf8"]: # no need lse/dropout kernels for logits, mask, bias in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys() ): - pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "f", "f", "f", "f", logits, bias, "f", "t", squant, mask, "f", "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", "t", squant, mask, "f", "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "f", "f", "f", "f", logits, bias, "f", "t", squant, mask, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", "t", squant, mask, "f")) # fmt: skip elif dtype in ["fp8fp16", "fp8bf16"]: pass # TODO else: @@ -690,7 +679,6 @@ def get_fwd_blobs( cond &= pipeline.F_bias in ["no", "alibi"] cond &= pipeline.F_squant == "f" cond &= pipeline.F_skip == "f" - cond &= pipeline.F_sink == "f" if not cond: continue # PyTorch integration @@ -700,7 +688,6 @@ def get_fwd_blobs( cond &= pipeline.F_bias in ["no", "bias"] cond &= pipeline.F_squant == "f" cond &= pipeline.F_skip == "f" - cond &= pipeline.F_sink == "f" if not cond: continue # Aiter(mha_fwd) integration diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index b95148cbc92..a952800806e 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -265,7 +265,6 @@ struct fmha_fwd_args ck_tile::index_t window_size_left; ck_tile::index_t window_size_right; - ck_tile::index_t sink_size; ck_tile::index_t mask_type; ck_tile::index_t min_seqlen_q; @@ -352,7 +351,6 @@ struct fmha_fwd_pagedkv_args ck_tile::index_t window_size_left; ck_tile::index_t window_size_right; - ck_tile::index_t sink_size; ck_tile::index_t mask_type; ck_tile::index_t min_seqlen_q; }; @@ -443,7 +441,6 @@ struct fmha_fwd_splitkv_args ck_tile::index_t window_size_left; ck_tile::index_t window_size_right; - ck_tile::index_t sink_size; ck_tile::index_t mask_type; }; @@ -614,7 +611,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.nhead_stride_o, args.window_size_left, args.window_size_right, - args.sink_size, args.mask_type, args.min_seqlen_q, args.p_drop, @@ -664,7 +660,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.batch_stride_o, args.window_size_left, args.window_size_right, - args.sink_size, args.mask_type, args.p_drop, args.s_randval, @@ -732,7 +727,6 @@ auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args) args.batch_stride_v, args.window_size_left, args.window_size_right, - args.sink_size, args.mask_type, args.min_seqlen_q); } @@ -778,7 +772,6 @@ auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args) args.batch_stride_o, args.window_size_left, args.window_size_right, - args.sink_size, args.mask_type); } }(); @@ -845,7 +838,6 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.split_stride_o_acc, args.window_size_left, args.window_size_right, - args.sink_size, args.mask_type); } else @@ -893,7 +885,6 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.split_stride_o_acc, args.window_size_left, args.window_size_right, - args.sink_size, args.mask_type); } }(); @@ -1140,8 +1131,7 @@ template + bool kSkipMinSeqlenQ_ = false> struct fmha_fwd_traits_ { static constexpr ck_tile::index_t HDim = HDim_; @@ -1167,7 +1157,6 @@ struct fmha_fwd_traits_ static constexpr bool kPadDv = kPadDv_; static constexpr bool kUseTrLoad = kUseTrLoad_; static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; - static constexpr bool kHasSink = kHasSink_; }; template @@ -1194,8 +1183,7 @@ template + bool kSkipMinSeqlenQ_ = false> struct fmha_fwd_pagedkv_traits_ { static constexpr ck_tile::index_t HDim = HDim_; @@ -1220,7 +1208,6 @@ struct fmha_fwd_pagedkv_traits_ static constexpr bool kPadD = kPadD_; static constexpr bool kPadDv = kPadDv_; static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; - static constexpr bool kHasSink = kHasSink_; }; template @@ -1243,7 +1230,6 @@ template @@ -1358,7 +1343,6 @@ struct fmha_fwd_traits bool has_dropout; bool do_fp8_static_quant; bool skip_min_seqlen_q = false; - bool has_sink = false; // TODO: padding check is inside this api }; float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&); @@ -1377,7 +1361,6 @@ struct fmha_fwd_pagedkv_traits bool use_pagedkv = true; bool do_fp8_static_quant = false; bool skip_min_seqlen_q = false; - bool has_sink = false; // TODO: padding check is inside this api }; @@ -1397,7 +1380,6 @@ struct fmha_fwd_splitkv_traits bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum bool has_lse; bool do_fp8_static_quant; - bool has_sink = false; // TODO: padding check is inside this api }; float fmha_fwd_splitkv(fmha_fwd_splitkv_traits, diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index a3fc7a9611e..8a663d038d1 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -907,7 +907,6 @@ fwd_result fmha_fwd_run(mode_enum mode, traits.has_logits_soft_cap = 0.f < logits_soft_cap; traits.mask_type = mask.type; traits.bias_type = bias.type; - traits.has_sink = mask.sink > 0 ? true : false; traits.has_lse = lse; traits.do_fp8_static_quant = squant; @@ -1073,7 +1072,6 @@ fwd_result fmha_fwd_run(mode_enum mode, args.window_size_left = mask.left; args.window_size_right = mask.right; - args.sink_size = mask.sink; args.mask_type = static_cast(mask.type); if constexpr(std::is_same_v>) @@ -1662,7 +1660,7 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::reference_batched_masking( s_host_ref, ck_tile::make_generic_attention_mask_from_lr_window( - mask.left, mask.right, mask.sink, real_seqlen_q, real_seqlen_k)); + mask.left, mask.right, real_seqlen_q, real_seqlen_k)); } else { @@ -1674,7 +1672,6 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::make_generic_attention_mask_from_lr_window( mask.left, mask.right, - mask.sink, real_seqlen_q, real_seqlen_k, mask.type == mask_enum::mask_top_left)); @@ -1684,7 +1681,6 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::make_generic_attention_mask_from_lr_window( mask.left, mask.right, - mask.sink, real_seqlen_q, real_seqlen_k, mask.type == mask_enum::mask_top_left)); diff --git a/example/ck_tile/01_fmha/mask.hpp b/example/ck_tile/01_fmha/mask.hpp index aa30db0d6f1..2dfe0e7c529 100644 --- a/example/ck_tile/01_fmha/mask.hpp +++ b/example/ck_tile/01_fmha/mask.hpp @@ -25,7 +25,6 @@ struct mask_info ck_tile::index_t seqlen_k; ck_tile::index_t y, x; ck_tile::index_t left, right; // FA style SWA left/right - ck_tile::index_t sink; void serialize(std::ostream& os) const { @@ -59,14 +58,13 @@ struct mask_info ck_tile::index_t window_size = std::stoi(v); ck_tile::index_t left_size = -1; ck_tile::index_t right_size = 0; - ck_tile::index_t sink_size = 0; if(window_size > 0) { left_size = window_size / 2; right_size = window_size - 1 - left_size; } auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( - left_size, right_size, sink_size, y_total, x_total, t == "xt"); + left_size, right_size, y_total, x_total, t == "xt"); tmp.type = t == "xt" ? mask_enum::mask_top_left : mask_enum::mask_bottom_right; tmp.y = r.at(ck_tile::number<0>{}); @@ -81,54 +79,27 @@ struct mask_info { throw std::invalid_argument("invalid mask value: " + str); } - tmp.type = mask_enum::window_generic; - ck_tile::index_t v0 = atoi(v.substr(0, found_1).c_str()); - auto found_2 = v.find(',', found_1 + 1); - ck_tile::index_t v1 = 0; - ck_tile::index_t sink = 0; - // ck_tile::index_t v1 = atoi(v.substr(found_1 + 1).c_str()); - // TODO: some validation + ck_tile::index_t v0 = std::stoi(v.substr(0, found_1)); + ck_tile::index_t v1 = std::stoi(v.substr(found_1 + 1)); if(t == "t") { - if(found_2 != std::string::npos) - { - v1 = atoi(v.substr(found_1 + 1, found_2 - found_1 - 1).c_str()); - sink = atoi(v.substr(found_2 + 1).c_str()); - } - else - { - v1 = atoi(v.substr(found_1 + 1).c_str()); - sink = 0; - } tmp.type = mask_enum::mask_top_left; auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( - v0, v1, sink, y_total, x_total, true); + v0, v1, y_total, x_total, true); tmp.y = r.at(ck_tile::number<0>{}); tmp.x = r.at(ck_tile::number<1>{}); tmp.left = v0; tmp.right = v1; - tmp.sink = sink; } else if(t == "b") { - if(found_2 != std::string::npos) - { - v1 = atoi(v.substr(found_1 + 1, found_2 - found_1 - 1).c_str()); - sink = atoi(v.substr(found_2 + 1).c_str()); - } - else - { - v1 = atoi(v.substr(found_1 + 1).c_str()); - sink = 0; - } tmp.type = mask_enum::mask_bottom_right; auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( - v0, v1, sink, y_total, x_total, false); + v0, v1, y_total, x_total, false); tmp.y = r.at(ck_tile::number<0>{}); tmp.x = r.at(ck_tile::number<1>{}); tmp.left = v0; tmp.right = v1; - tmp.sink = sink; } else if(t == "g") { @@ -137,7 +108,6 @@ struct mask_info tmp.x = v1; tmp.left = v0; // TODO: don't use this? tmp.right = v1; - tmp.sink = 0; } } else @@ -156,7 +126,6 @@ struct mask_info tmp.x = 1; tmp.left = -1; tmp.right = 0; - tmp.sink = 0; } else if(str == "2" || str == "b") { @@ -165,7 +134,6 @@ struct mask_info tmp.x = seqlen_k - seqlen_q + 1; tmp.left = -1; tmp.right = 0; - tmp.sink = 0; } else { diff --git a/example/ck_tile/01_fmha/script/correct_test_fwd_sink.sh b/example/ck_tile/01_fmha/script/correct_test_fwd_sink.sh deleted file mode 100644 index 712db522580..00000000000 --- a/example/ck_tile/01_fmha/script/correct_test_fwd_sink.sh +++ /dev/null @@ -1,74 +0,0 @@ -#!/bin/bash -# TODO: run this script from CK root or build directory -EXE="$(find . -name tile_example_fmha_fwd -type f | head -n 1)" -KNAME=1 - -export CK_WARMUP=0 -export CK_REPEAT=1 - -COMMON_ARGS='-v=1 -warmup=0 -repeat=1' -# mode=0 -# export HIP_VISIBLE_DEVICES=4 - -TEST_SPLITKV=0 -TEST_APPENDKV=0 -# options: -# -s: run splitkv tests -# -a: run appendkv tests -while getopts ":sa" opt; do - case "${opt}" in - s) - TEST_SPLITKV=1 - ;; - a) - TEST_APPENDKV=1 - ;; - *) - ;; - esac -done - -run_fp16_bf16_tests() { - local NUM_SPLITS="1" - local PAGE_BLOCK_SIZE="0" - local CACHE_BATCH_IDX="0" - - if [ $TEST_SPLITKV -eq 1 ] ; then - NUM_SPLITS="$NUM_SPLITS 2 3" - PAGE_BLOCK_SIZE="$PAGE_BLOCK_SIZE 128" - CACHE_BATCH_IDX="$CACHE_BATCH_IDX 1" - fi - - for prec in "fp16"; do - for mode in 1 0 ; do - for perm in 0 1 ; do - for vlayout in "r" "c" ; do - for batch in 1 4; do - for head in 1; do - for h_k in 1; do - for q_seq in 128 512 ; do - for kv_seq in 128 1024; do - for hdim in 32 64 128 256; do #256 - for lse in 0 1 ; do - for bias in "e" ; do - for p_drop in 0.0 0.2; do # 0.0 - for mask in "t:2,0,4" "b:1,0,2"; do - for num_splits in $NUM_SPLITS ; do - for page_block_size in $PAGE_BLOCK_SIZE ; do - for cache_batch_idx in $CACHE_BATCH_IDX ; do - - # $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=$batch -h=$head -h_k=$h_k -d=16, -d_v=$hdim -s=$q_seq -s_k=$kv_seq -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS -mask=$mask - - done ; done ; done ; done ; done - done ; done ; done ; done ; done - done ; done ; done ; done ; done - done ; done -} - - -set -x - -run_fp16_bf16_tests - -set +x diff --git a/example/ck_tile/01_fmha/script/run_full_test.sh b/example/ck_tile/01_fmha/script/run_full_test.sh index 751fded2de1..5c2a5a4b3d0 100755 --- a/example/ck_tile/01_fmha/script/run_full_test.sh +++ b/example/ck_tile/01_fmha/script/run_full_test.sh @@ -36,7 +36,6 @@ function print_log_header(){ #run verification tests time example/ck_tile/01_fmha/script/smoke_test_fwd.sh time example/ck_tile/01_fmha/script/smoke_test_bwd.sh -time example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh #run performance benchmarks export fmha_fwd_log="perf_fmha_fwd_$GPU_arch.log" diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh deleted file mode 100755 index b554e16ea7f..00000000000 --- a/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh +++ /dev/null @@ -1,83 +0,0 @@ -#!/bin/bash -# TODO: run this script from CK root or build directory -#EXE="/code/composable_kernel/build/bin/tile_example_fmha_fwd" -set -euo pipefail - -SCRIPT_DIR=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd) -EXE_NAME=tile_example_fmha_fwd -EXE="$(find . -name $EXE_NAME -type f | head -n 1)" -KNAME=1 -GPU_arch=$GPU_arch -if [ -z "$GPU_arch" ] ; then - GPU_arch=$(rocminfo | grep -E 'Name:\s+gfx' | head -n1 | awk '{print $2}') -fi -set -x - -COMMON_ARGS='-v=1 -warmup=0 -repeat=1' - - -$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=t:2,0,2 - -# window_size[2,0], sink_size = 2 - -# x=1/y=3 -# 1 * * * * * * * 1 * * * * * * * -# 1 1 * * * * * * 1 1 * * * * * * -# 1 1 1 * * * * * ----> 1 1 1 * * * * * -# * 1 1 1 * * * * 1 1 1 1 * * * * -# * * 1 1 1 * * * 1 1 1 1 1 * * * -# * * * 1 1 1 * * 1 1 * 1 1 1 * * -# * * * * 1 1 1 * 1 1 * * 1 1 1 * -# * * * * * 1 1 1 1 1 * * * 1 1 1 -# l=2/r=0(tl) l=2/r=0/s=2(tl) - -$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=t:0,3,2 #-mask=b:3,0,2 - -# x=4/y=1 -# 1 1 1 1 * * * * 1 1 1 1 * * * * -# * 1 1 1 1 * * * 1 1 1 1 1 * * * -# * * 1 1 1 1 * * ----> 1 1 1 1 1 1 * * -# * * * 1 1 1 1 * 1 1 * 1 1 1 1 * -# * * * * 1 1 1 1 1 1 * * 1 1 1 1 -# l=0/r=3(tl) l=0/r=3/s=2(tl) -# l=3/r=0(br) l=3/r=0/s=2(br) - - -$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:1,0,2 - -# x=4/y=-1 -# * * 1 1 * * * * 1 1 1 1 * * * * -# * * * 1 1 * * * 1 1 * 1 1 * * * -# * * * * 1 1 * * ----> 1 1 * * 1 1 * * -# * * * * * 1 1 * 1 1 * * * 1 1 * -# * * * * * * 1 1 1 1 * * * * 1 1 -# l=1/r=0(br) l=1/r=0/s=2(br) - - -$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:2,0,2 - -# x=-1/y=5 - -# * * * * * * * * * * * * -# * * * * * * * * * * * * -# 1 * * * * * 1 * * * * * -# 1 1 * * * * 1 1 * * * * -# 1 1 1 * * * ----> 1 1 1 * * * -# * 1 1 1 * * 1 1 1 1 * * -# * * 1 1 1 * 1 1 1 1 1 * -# * * * 1 1 1 1 1 * 1 1 1 -# l=2/r=0(br) l=2/r=0/s=2(br) - - -$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:-1,1,2 -# x=-1/y=8 -# * * * * * * * * * * -# * * * * * * * * * * -# 1 * * * * ----> 1 * * * * -# 1 1 * * * 1 1 * * * -# 1 1 1 * * 1 1 1 * * -# 1 1 1 1 * 1 1 1 1 * -# 1 1 1 1 1 1 1 1 1 1 -# 1 1 1 1 1 1 1 1 1 1 -# l=2/r=0(br) l=2/r=0/s=2(br) - \ No newline at end of file diff --git a/include/ck_tile/host/reference/reference_batched_masking.hpp b/include/ck_tile/host/reference/reference_batched_masking.hpp index 93329e99ce6..eece7fc3a81 100644 --- a/include/ck_tile/host/reference/reference_batched_masking.hpp +++ b/include/ck_tile/host/reference/reference_batched_masking.hpp @@ -20,7 +20,7 @@ CK_TILE_HOST void reference_batched_masking(HostTensor& c_b_m_n, cons { for(int m = 0; m < M; ++m) { - if(mask.IsOutOfSinkBound(m, n)) + if(mask.IsOutOfBound(m, n)) c_b_m_n(batch, m, n) = -ck_tile::numeric::infinity(); } } diff --git a/include/ck_tile/ops/fmha/block/block_masking.hpp b/include/ck_tile/ops/fmha/block/block_masking.hpp index 5484c92f014..2c45945fac0 100644 --- a/include/ck_tile/ops/fmha/block/block_masking.hpp +++ b/include/ck_tile/ops/fmha/block/block_masking.hpp @@ -86,22 +86,21 @@ struct GenericAttentionMask static constexpr const char* name = impl::MaskName::name; CK_TILE_HOST_DEVICE GenericAttentionMask(index_t y_total_, index_t x_total_) - : GenericAttentionMask(0, 0, 0, y_total_, x_total_) + : GenericAttentionMask(0, 0, y_total_, x_total_) { } CK_TILE_HOST_DEVICE - GenericAttentionMask(index_t y_, index_t x_, index_t sink_, index_t y_total_, index_t x_total_) - : y(y_), x(x_), sink(sink_), y_total(y_total_), x_total(x_total_) + GenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_) + : y(y_), x(x_), y_total(y_total_), x_total(x_total_) { } template CK_TILE_HOST_DEVICE GenericAttentionMask(const MaskCoordinates& mask_coord) : y(mask_coord.at(number<0>{})), x(mask_coord.at(number<1>{})), - sink(mask_coord.at(number<2>{})), - y_total(mask_coord.at(number<3>{})), - x_total(mask_coord.at(number<4>{})) + y_total(mask_coord.at(number<2>{})), + x_total(mask_coord.at(number<3>{})) { } @@ -142,44 +141,6 @@ struct GenericAttentionMask } } - template - CK_TILE_HOST_DEVICE constexpr auto - GetSinkTileRangeAlongX(index_t i_y, number, number) const - { - if constexpr(!IsMasking) - { - return ck_tile::make_tuple(0, 0, x_total); - } - else - { - // get the tile start/end range assum we loop over along X tile by tile - index_t x_start = [&]() { - if constexpr(IsLocal) - { - index_t tmp = max(-y + i_y + 1, 0); - return (tmp / XTile) * XTile; // round to tile aligned - } - else - { - return 0; - } - }(); - - // TODO: end could be negative, we ignore clamp here, and let caller to check - // ... in which case end-start is negative - index_t x_end = [&]() { - index_t tmp = min(i_y + YTile - 1 + x, x_total); - return ((tmp + XTile - 1) / XTile) * XTile; - }(); - - index_t sink_seq_end = sink > 0 ? ((sink + XTile - 1) / XTile) * XTile : 0; - if(x_start <= sink_seq_end && sink > 0) - return ck_tile::make_tuple(0, 0, x_end); - else - return ck_tile::make_tuple(sink_seq_end, x_start, x_end); - } - } - // to get the loop length along Y axis, return index:[start, end), end-start=length // use this if need loop over Y axis tile by tile (like q-seqlen loopover) // TODO: y_end still could be negative, so end-start could be negative(need check) @@ -234,30 +195,6 @@ struct GenericAttentionMask } } - CK_TILE_HOST_DEVICE constexpr auto IsOutOfSinkBound(index_t i_y, index_t i_x) const - { - if constexpr(!IsMasking) - return i_x >= x_total; - // no need to do min/max here, since i_x will never be < 0 or >= x_total - index_t x_start = -y + i_y + 1; - index_t x_end = min(i_y + x, x_total); - - if constexpr(IsLocal) - { - if((i_x < sink) && (y < y_total) && ((i_y + x) > 1) && i_y < x_total) - return false; - else - return i_x < x_start || i_x >= x_end; - } - else - { - if((i_x < sink) && (y < y_total) && ((i_y + x) > 1) && i_y < x_total) - return false; - else - return i_x >= x_end || i_y >= y_total; - } - } - // if current tile is at the edge, means need per-pixel mask check. // otherwise no need to check per-pixel // Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y() @@ -300,7 +237,7 @@ struct GenericAttentionMask } private: - index_t y, x, sink; + index_t y, x; index_t y_total, x_total; }; @@ -323,23 +260,21 @@ struct SimplifiedGenericAttentionMask static constexpr const char* name = impl::SimplifiedMaskName::name; CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(index_t y_total_, index_t x_total_) - : SimplifiedGenericAttentionMask(0, 0, 0, y_total_, x_total_) + : SimplifiedGenericAttentionMask(0, 0, y_total_, x_total_) { } CK_TILE_HOST_DEVICE - SimplifiedGenericAttentionMask( - index_t y_, index_t x_, index_t sink_, index_t y_total_, index_t x_total_) - : y(y_), x(x_), sink(sink_), y_total(y_total_), x_total(x_total_) + SimplifiedGenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_) + : y(y_), x(x_), y_total(y_total_), x_total(x_total_) { } template CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(const MaskCoordinates& mask_coord) : y(mask_coord.at(number<0>{})), x(mask_coord.at(number<1>{})), - sink(mask_coord.at(number<2>{})), - y_total(mask_coord.at(number<3>{})), - x_total(mask_coord.at(number<4>{})) + y_total(mask_coord.at(number<2>{})), + x_total(mask_coord.at(number<3>{})) { } @@ -373,38 +308,6 @@ struct SimplifiedGenericAttentionMask } } - template - CK_TILE_HOST_DEVICE constexpr auto - GetSinkTileRangeAlongX(index_t i_y, number, number) const - { - if constexpr(!IsMasking) - { - return ck_tile::make_tuple(0, 0, x_total); - } - else - { - // get the tile start/end range assum we loop over along X tile by tile - index_t x_start = [&]() { - index_t tmp = max(-y + i_y + 1, 0); - return (tmp / XTile) * XTile; // round to tile aligned - }(); - - // TODO: end could be negative, we ignore clamp here, and let caller to check - // ... in which case end-start is negative - index_t x_end = [&]() { - index_t tmp = min(i_y + YTile - 1 + x, x_total); - return ((tmp + XTile - 1) / XTile) * XTile; - }(); - - index_t sink_seq_end = sink > 0 ? ((sink + XTile - 1) / XTile) * XTile : 0; - - if(x_start <= sink_seq_end && sink > 0) - return ck_tile::make_tuple(0, 0, x_end); - else - return ck_tile::make_tuple(sink_seq_end, x_start, x_end); - } - } - template CK_TILE_HOST_DEVICE constexpr auto GetTileRangeAlongX(index_t i_y, number height, @@ -422,29 +325,6 @@ struct SimplifiedGenericAttentionMask ck_tile::min(origin_end, split_end)); } - template - CK_TILE_HOST_DEVICE constexpr auto GetSinkTileRangeAlongX(index_t i_y, - number height, - number width, - index_t num_splits, - index_t i_split) const - { - auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width); - const index_t x_per_split = ck_tile::max(1, integer_divide_ceil(x_total, num_splits)); - const index_t split_start = x_per_split * i_split; // 128 - const index_t split_end = ck_tile::min(x_total, split_start + x_per_split); // 256 - const index_t sink_seq_end = sink > 0 ? ((sink + width - 1) / width) * width : 0; - const index_t start = ck_tile::max(origin_start, split_start); - const index_t end = ck_tile::min(origin_end, split_end); - const bool is_first_intersecting_split = - (split_start <= origin_start && split_end >= origin_start); - const bool sink_in_range = (sink_seq_end <= start); - - const index_t sink_offset = - (is_first_intersecting_split && sink_in_range) ? sink_seq_end : 0; - return ck_tile::make_tuple(sink_offset, start, end); - } - // to get the loop length along Y axis, return index:[start, end), end-start=length // use this if need loop over Y axis tile by tile (like q-seqlen loopover) // TODO: y_end still could be negative, so end-start could be negative(need check) @@ -488,20 +368,9 @@ struct SimplifiedGenericAttentionMask { index_t x_start = -y + i_y + 1; // this could be negative, but it's fine index_t x_end = min(i_y + x, x_total); // need min in case x is padded - return i_x < x_start || i_x >= x_end || i_y >= y_total; - } - } - CK_TILE_HOST_DEVICE constexpr auto IsOutOfSinkBound(index_t i_y, index_t i_x) const - { - if constexpr(!IsMasking) - return i_x >= x_total; - index_t x_start = -y + i_y + 1; // this could be negative, but it's fine - index_t x_end = min(i_y + x, x_total); // need min in case x is padded - if((i_x < sink) && (y < y_total) && ((i_y + x) > 1) && i_y < x_total) - return false; - else return i_x < x_start || i_x >= x_end || i_y >= y_total; + } } // if current tile is at the edge, means need per-pixel mask check. @@ -537,7 +406,7 @@ struct SimplifiedGenericAttentionMask } private: - index_t y, x, sink; + index_t y, x; index_t y_total, x_total; }; @@ -738,7 +607,6 @@ struct SimplifiedRatioAttentionMask CK_TILE_HOST_DEVICE constexpr auto make_generic_attention_mask_coordinates_from_lr_window(index_t left_size, index_t right_size, - index_t sink_size, index_t y_total, index_t x_total, bool is_top_left = true) @@ -756,21 +624,7 @@ make_generic_attention_mask_coordinates_from_lr_window(index_t left_size, index_t x = 1 + right_size + x_tmp; index_t y = 1 + left_size + y_tmp; - return ck_tile::make_tuple(y, x, sink_size, y_total, x_total); -} - -template -CK_TILE_HOST_DEVICE constexpr auto -make_generic_attention_mask_from_lr_window(index_t left_size, - index_t right_size, - index_t sink_size, - index_t y_total, - index_t x_total, - bool is_top_left = true) -{ - auto r = make_generic_attention_mask_coordinates_from_lr_window( - left_size, right_size, sink_size, y_total, x_total, is_top_left); - return MaskType{r.at(number<0>{}), r.at(number<1>{}), sink_size, y_total, x_total}; + return ck_tile::make_tuple(y, x, y_total, x_total); } template @@ -782,7 +636,7 @@ make_generic_attention_mask_from_lr_window(index_t left_size, bool is_top_left = true) { auto r = make_generic_attention_mask_coordinates_from_lr_window( - left_size, right_size, 0, y_total, x_total, is_top_left); - return MaskType{r.at(number<0>{}), r.at(number<1>{}), 0, y_total, x_total}; + left_size, right_size, y_total, x_total, is_top_left); + return MaskType{r.at(number<0>{}), r.at(number<1>{}), y_total, x_total}; } } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/block/variants.hpp b/include/ck_tile/ops/fmha/block/variants.hpp index 245f5dc5682..d8b0cdbb86b 100644 --- a/include/ck_tile/ops/fmha/block/variants.hpp +++ b/include/ck_tile/ops/fmha/block/variants.hpp @@ -162,17 +162,6 @@ struct StandardAttention { return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx); } - - template - __device__ __forceinline__ bool LogitsSinkMask(const Params& params, - [[maybe_unused]] uint32_t batch_idx, - uint32_t qo_idx, - uint32_t kv_idx, - [[maybe_unused]] uint32_t qo_head_idx, - [[maybe_unused]] uint32_t kv_head_idx) const - { - return !params.impl_mask.IsOutOfSinkBound(qo_idx, kv_idx); - } }; template @@ -235,17 +224,6 @@ struct LogitsSoftCap { return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx); } - - template - __device__ __forceinline__ bool LogitsSinkMask(const Params& params, - [[maybe_unused]] uint32_t batch_idx, - uint32_t qo_idx, - uint32_t kv_idx, - [[maybe_unused]] uint32_t qo_head_idx, - [[maybe_unused]] uint32_t kv_head_idx) const - { - return !params.impl_mask.IsOutOfSinkBound(qo_idx, kv_idx); - } }; constexpr uint32_t CUSTOM_MASK = 1U; @@ -319,17 +297,6 @@ struct ComposedAttention { return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx); } - - template - __device__ __forceinline__ bool LogitsSinkMask(const Params& params, - [[maybe_unused]] uint32_t batch_idx, - uint32_t qo_idx, - uint32_t kv_idx, - [[maybe_unused]] uint32_t qo_head_idx, - [[maybe_unused]] uint32_t kv_head_idx) const - { - return !params.impl_mask.IsOutOfSinkBound(qo_idx, kv_idx); - } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp index cd5b180a39d..3b476299e15 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp @@ -198,7 +198,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel struct FmhaFwdMaskKargs { // ck_tile::index_t window_size_left, window_size_right; - ck_tile::index_t window_size_left, window_size_right, sink_size; + ck_tile::index_t window_size_left, window_size_right; ck_tile::GenericAttentionMaskEnum mask_type; }; @@ -362,7 +362,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, - ck_tile::index_t sink_size, ck_tile::index_t mask_type, float p_drop, bool s_randval, @@ -426,7 +425,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel { kargs.window_size_left = window_size_left; kargs.window_size_right = window_size_right; - kargs.sink_size = sink_size; kargs.mask_type = static_cast(mask_type); } if constexpr(kStoreLSE) @@ -511,7 +509,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel ck_tile::index_t batch_stride_v, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, - ck_tile::index_t sink_size, ck_tile::index_t mask_type, float p_drop, bool s_randval, @@ -573,7 +570,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel { kargs.window_size_left = window_size_left; kargs.window_size_right = window_size_right; - kargs.sink_size = sink_size; kargs.mask_type = static_cast(mask_type); } if constexpr(kStoreLSE) @@ -1030,7 +1026,6 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel return ck_tile::make_generic_attention_mask_from_lr_window( kargs.window_size_left, kargs.window_size_right, - kargs.sink_size, kargs.seqlen_q, kargs.seqlen_k, kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 414eae3651e..fba3065842a 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -56,7 +56,6 @@ struct FmhaFwdKernel static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ; - static constexpr bool kHasSink = FmhaPipeline::kHasSink; using AttentionVariant = ck_tile::remove_cvref_t; using FmhaMask = ck_tile::remove_cvref_t; @@ -113,7 +112,7 @@ struct FmhaFwdKernel (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) + (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + - (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kUseTrLoad ? "_trload" : "_ntrload") + (kHasSink ? "_sink" : "_nsink"); + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kUseTrLoad ? "_trload" : "_ntrload"); #undef _SS_ #undef _TS_ // clang-format on @@ -201,7 +200,7 @@ struct FmhaFwdKernel struct FmhaFwdMaskKargs { // ck_tile::index_t window_size_left, window_size_right; - ck_tile::index_t window_size_left, window_size_right, sink_size; + ck_tile::index_t window_size_left, window_size_right; ck_tile::GenericAttentionMaskEnum mask_type; }; @@ -375,7 +374,6 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, - ck_tile::index_t sink_size, ck_tile::index_t mask_type, float p_drop, bool s_randval, @@ -434,7 +432,6 @@ struct FmhaFwdKernel { kargs.window_size_left = window_size_left; kargs.window_size_right = window_size_right; - kargs.sink_size = sink_size; kargs.mask_type = static_cast(mask_type); } if constexpr(kStoreLSE) @@ -521,7 +518,6 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, - ck_tile::index_t sink_size, ck_tile::index_t mask_type, float p_drop, bool s_randval, @@ -569,7 +565,6 @@ struct FmhaFwdKernel batch_stride_o, window_size_left, window_size_right, - sink_size, mask_type, p_drop, s_randval, @@ -620,7 +615,6 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, - ck_tile::index_t sink_size, ck_tile::index_t mask_type, float p_drop, bool s_randval, @@ -668,7 +662,6 @@ struct FmhaFwdKernel batch_stride_o, window_size_left, window_size_right, - sink_size, mask_type, p_drop, s_randval, @@ -713,7 +706,6 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, - ck_tile::index_t sink_size, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q, float p_drop, @@ -773,7 +765,6 @@ struct FmhaFwdKernel { kargs.window_size_left = window_size_left; kargs.window_size_right = window_size_right; - kargs.sink_size = sink_size; kargs.mask_type = static_cast(mask_type); } if constexpr(kStoreLSE) @@ -857,7 +848,6 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, - ck_tile::index_t sink_size, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q, float p_drop, @@ -901,7 +891,6 @@ struct FmhaFwdKernel nhead_stride_o, window_size_left, window_size_right, - sink_size, mask_type, min_seqlen_q, p_drop, @@ -948,7 +937,6 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, - ck_tile::index_t sink_size, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q, float p_drop, @@ -992,7 +980,6 @@ struct FmhaFwdKernel nhead_stride_o, window_size_left, window_size_right, - sink_size, mask_type, min_seqlen_q, p_drop, @@ -1484,7 +1471,6 @@ struct FmhaFwdKernel return ck_tile::make_generic_attention_mask_from_lr_window( kargs.window_size_left, kargs.window_size_right, - kargs.sink_size, kargs.seqlen_q, kargs.seqlen_k, kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); @@ -2214,7 +2200,6 @@ struct FmhaFwdKernel return ck_tile::make_generic_attention_mask_from_lr_window( kargs.window_size_left, kargs.window_size_right, - kargs.sink_size, kargs.seqlen_q, kargs.seqlen_k, kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp index 712892387cc..a2e6f083616 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp @@ -55,7 +55,6 @@ struct FmhaFwdPagedKVKernel static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ; static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV; - static constexpr bool kHasSink = FmhaPipeline::kHasSink; using AttentionVariant = ck_tile::remove_cvref_t; using FmhaMask = ck_tile::remove_cvref_t; @@ -102,7 +101,7 @@ struct FmhaFwdPagedKVKernel (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) + (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + - (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ) + (kHasSink ? "_sink" : "_nsink" ); + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ); #undef _SS_ #undef _TS_ // clang-format on @@ -190,7 +189,7 @@ struct FmhaFwdPagedKVKernel struct FmhaFwdMaskKargs { // ck_tile::index_t window_size_left, window_size_right; - ck_tile::index_t window_size_left, window_size_right, sink_size; + ck_tile::index_t window_size_left, window_size_right; ck_tile::GenericAttentionMaskEnum mask_type; }; @@ -327,7 +326,6 @@ struct FmhaFwdPagedKVKernel ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, - ck_tile::index_t sink_size, ck_tile::index_t mask_type) { Kargs kargs{{q_ptr, @@ -381,7 +379,6 @@ struct FmhaFwdPagedKVKernel { kargs.window_size_left = window_size_left; kargs.window_size_right = window_size_right; - kargs.sink_size = sink_size; kargs.mask_type = static_cast(mask_type); } if constexpr(kStoreLSE) @@ -456,7 +453,6 @@ struct FmhaFwdPagedKVKernel ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, - ck_tile::index_t sink_size, ck_tile::index_t mask_type) { return MakeKargsImpl(q_ptr, @@ -499,7 +495,6 @@ struct FmhaFwdPagedKVKernel batch_stride_o, window_size_left, window_size_right, - sink_size, mask_type); } @@ -541,7 +536,6 @@ struct FmhaFwdPagedKVKernel ck_tile::index_t batch_stride_v, // only used for paged-kvcache ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, - ck_tile::index_t sink_size, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q) { @@ -596,7 +590,6 @@ struct FmhaFwdPagedKVKernel { kargs.window_size_left = window_size_left; kargs.window_size_right = window_size_right; - kargs.sink_size = sink_size; kargs.mask_type = static_cast(mask_type); } if constexpr(kStoreLSE) @@ -667,7 +660,6 @@ struct FmhaFwdPagedKVKernel ck_tile::index_t batch_stride_v, // only used for paged-kvcache ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, - ck_tile::index_t sink_size, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q) { @@ -707,7 +699,6 @@ struct FmhaFwdPagedKVKernel batch_stride_v, window_size_left, window_size_right, - sink_size, mask_type, min_seqlen_q); } @@ -1266,7 +1257,6 @@ struct FmhaFwdPagedKVKernel return ck_tile::make_generic_attention_mask_from_lr_window( kargs.window_size_left, kargs.window_size_right, - kargs.sink_size, kargs.seqlen_q, kargs.seqlen_k, kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index 1b5deb26e0c..a6e44c7293c 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -51,7 +51,6 @@ struct FmhaFwdSplitKVKernel static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV; - static constexpr bool kHasSink = FmhaPipeline::Problem::kHasSink; static constexpr bool kMergeNumHeadGroupsSeqLenQ = FmhaPipeline::Problem::kMergeNumHeadGroupsSeqLenQ; using AttentionVariant = ck_tile::remove_cvref_t; @@ -102,7 +101,7 @@ struct FmhaFwdSplitKVKernel "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) + (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + - (kDoFp8StaticQuant ? "_squant" : "_nsquant") + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ) + (kHasSink ? "_sink" : "_nsink" ); + (kDoFp8StaticQuant ? "_squant" : "_nsquant") + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ); #undef _SS_ #undef _TS_ // clang-format on @@ -199,7 +198,7 @@ struct FmhaFwdSplitKVKernel struct MaskKargs { // ck_tile::index_t window_size_left, window_size_right; - ck_tile::index_t window_size_left, window_size_right, sink_size; + ck_tile::index_t window_size_left, window_size_right; ck_tile::GenericAttentionMaskEnum mask_type; }; @@ -326,7 +325,6 @@ struct FmhaFwdSplitKVKernel ck_tile::index_t split_stride_o_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, - ck_tile::index_t sink_size, ck_tile::index_t mask_type) { Kargs kargs{{q_ptr, @@ -386,7 +384,6 @@ struct FmhaFwdSplitKVKernel { kargs.window_size_left = window_size_left; kargs.window_size_right = window_size_right; - kargs.sink_size = sink_size; kargs.mask_type = static_cast(mask_type); } if constexpr(kDoFp8StaticQuant) @@ -454,7 +451,6 @@ struct FmhaFwdSplitKVKernel ck_tile::index_t split_stride_o_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, - ck_tile::index_t sink_size, ck_tile::index_t mask_type) { Kargs kargs{{q_ptr, @@ -512,7 +508,6 @@ struct FmhaFwdSplitKVKernel { kargs.window_size_left = window_size_left; kargs.window_size_right = window_size_right; - kargs.sink_size = sink_size; kargs.mask_type = static_cast(mask_type); } if constexpr(kDoFp8StaticQuant) @@ -999,7 +994,6 @@ struct FmhaFwdSplitKVKernel return ck_tile::make_generic_attention_mask_from_lr_window( kargs.window_size_left, kargs.window_size_right, - kargs.sink_size, kargs.seqlen_q, kargs.seqlen_k, kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp index d73c673a58a..7a8e9a1d470 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp @@ -57,7 +57,6 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kIsPagedKV = Problem::kIsPagedKV; - static constexpr bool kHasSink = Problem::kHasSink; static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || @@ -229,22 +228,10 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS clear_tile(o_acc); set_tile(m, -numeric::infinity()); clear_tile(l); - const auto q_origin = q_dram_window.get_window_origin(); - const auto tile_range_result = [&mask, &q_origin]() { - if constexpr(kHasSink) - return mask.GetSinkTileRangeAlongX( - q_origin.at(number<0>{}), number{}, number{}); - else - { - auto [start, end] = - mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); - return ck_tile::make_tuple(0, start, end); - } - }(); - const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{}); - const auto logical_seqlen_k_start = tile_range_result.get(ck_tile::number<1>{}); - const auto logical_seqlen_k_end = tile_range_result.get(ck_tile::number<2>{}); - const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0); + + const auto q_origin = q_dram_window.get_window_origin(); + const auto [logical_seqlen_k_start, logical_seqlen_k_end] = + mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); // check early exit if no work to do if constexpr(FmhaMask::IsMasking || kPadSeqLenK) @@ -268,6 +255,7 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS return o_acc; } } + // k_dram_block_window const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset; const index_t physical_seqlen_k_end = logical_seqlen_k_end + kv_l2p_offset; @@ -286,36 +274,27 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS return physical_seqlen_k_start_; } }(); - const auto kv_load_start = (sink_seq_end == 0 && aligned_physical_seqlen_k_start > 0) - ? aligned_physical_seqlen_k_start - : 0; const index_t num_total_loop = - integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0) + - num_sink_loop; + integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0); auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window( - k_dram_block_window_lengths, {kv_load_start, 0}); - - const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); - const index_t bias_n_offset = [&]() { - if constexpr(kHasSink) - return kv_load_start; - else - return logical_seqlen_k_start - - (physical_seqlen_k_start - aligned_physical_seqlen_k_start); - }(); + k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0}); + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); auto bias_dram_window = make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), bias_dram_block_window_tmp.get_window_lengths(), - {bias_origin.at(number<0>{}), bias_n_offset}, + {bias_origin.at(number<0>{}), + logical_seqlen_k_start - (physical_seqlen_k_start - + aligned_physical_seqlen_k_start)}, // M/N Policy::template MakeBiasDramTileDistribution()); // v_dram_window auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window( v_dram_block_window_lengths, - {0, kv_load_start}, // TODO: hdim split? + {0, aligned_physical_seqlen_k_start}, // TODO: hdim split? Policy::template MakeVDramTileDistribution()); + auto q_tile = tile_elementwise_in(q_element_func, q); // prefetch K tile @@ -342,16 +321,9 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); k_block_tile = load_tile(k_dram_window); } - const bool is_sink_tile = ((num_sink_loop - 1) == i_total_loops); - const auto k_move_offset = [&]() { - if constexpr(kHasSink) - return is_sink_tile ? logical_seqlen_k_start - sink_seq_end + kN0 : kN0; - else - return kN0; - }(); auto physical_next_block_id_k = amd_wave_read_first_lane(k_page_block_navigator.prefetch_table_id( - i_page_block_k, k_dram_block_window, {k_move_offset, 0})); + i_page_block_k, k_dram_block_window, {kN0, 0})); auto physical_next_block_id_v = amd_wave_read_first_lane( v_page_block_navigator.prefetch_table_id(i_page_block_v, v_dram_window, {0, kK1})); @@ -470,7 +442,7 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS #endif } } - move_tile_window(bias_dram_window, {0, k_move_offset}); + move_tile_window(bias_dram_window, {0, kN0}); { const auto k_origin = k_page_block_navigator.to_global_window_origin( @@ -502,29 +474,14 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS number{}); if(need_perpixel_check) { - auto apply_mask = [&](auto&& mask_func) { - set_tile_if(s_acc, - -numeric::infinity(), - [&](auto tile_idx) { - const auto row = - q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = - k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return mask_func(row, col - kv_l2p_offset); - }); - }; - - if constexpr(kHasSink) - { - apply_mask([&](auto row, auto col) { - return mask.IsOutOfSinkBound(row, col); + set_tile_if( + s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = + q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = + k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col - kv_l2p_offset); }); - } - else - { - apply_mask( - [&](auto row, auto col) { return mask.IsOutOfBound(row, col); }); - } } } } @@ -690,12 +647,7 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS } // move K tile windows i_page_block_k = k_page_block_navigator.move_tile_window( - i_page_block_k, k_dram_block_window, {k_move_offset, 0}, physical_next_block_id_k); - physical_next_block_id_v = - amd_wave_read_first_lane(v_page_block_navigator.prefetch_table_id( - i_page_block_v, v_dram_window, {0, k_move_offset - kN0})); - i_page_block_v = v_page_block_navigator.move_tile_window( - i_page_block_v, v_dram_window, {0, k_move_offset - kN0}, physical_next_block_id_v); + i_page_block_k, k_dram_block_window, {kN0, 0}, physical_next_block_id_k); // tail { block_sync_lds(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp index 65e0eb7dfb3..4d1c38e0790 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp @@ -57,7 +57,6 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kIsPagedKV = Problem::kIsPagedKV; static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits; - static constexpr bool kHasSink = Problem::kHasSink; static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || @@ -257,23 +256,11 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS set_tile(m, -numeric::infinity()); clear_tile(l); - const auto q_origin = q_dram_window.get_window_origin(); - const auto tile_range_result = [&mask, &q_origin, num_splits, i_split]() { - if constexpr(kHasSink) - return mask.GetSinkTileRangeAlongX( - q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split); - else - { - auto [start, end] = mask.GetTileRangeAlongX( - q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split); - return ck_tile::make_tuple(0, start, end); - } - }(); - const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{}); - const auto logical_seqlen_k_start = tile_range_result.get(ck_tile::number<1>{}); - const auto logical_seqlen_k_end = tile_range_result.get(ck_tile::number<2>{}); + const auto q_origin = q_dram_window.get_window_origin(); + const auto [logical_seqlen_k_start, logical_seqlen_k_end] = mask.GetTileRangeAlongX( + q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split); - const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0); + // check early exit if no work to do if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits) { const index_t logical_num_total_loop = @@ -317,33 +304,24 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS return physical_seqlen_k_start_; } }(); - const auto kv_load_start = (sink_seq_end == 0 && aligned_physical_seqlen_k_start > 0) - ? aligned_physical_seqlen_k_start - : 0; const index_t num_total_loop = - integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0) + - num_sink_loop; + integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0); auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window( - k_dram_block_window_lengths, {kv_load_start, 0}); + k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0}); - const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); - const index_t bias_n_offset = [&]() { - if constexpr(kHasSink) - return kv_load_start; - else - return logical_seqlen_k_start - - (physical_seqlen_k_start - aligned_physical_seqlen_k_start); - }(); + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); auto bias_dram_window = make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), bias_dram_block_window_tmp.get_window_lengths(), - {bias_origin.at(number<0>{}), bias_n_offset}, + {bias_origin.at(number<0>{}), + logical_seqlen_k_start - (physical_seqlen_k_start - + aligned_physical_seqlen_k_start)}, // M/N Policy::template MakeBiasDramTileDistribution()); auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window( v_dram_block_window_lengths, - {0, kv_load_start}, // TODO: hdim split? + {0, aligned_physical_seqlen_k_start}, // TODO: hdim split? Policy::template MakeVDramTileDistribution()); // store Q into LDS @@ -391,13 +369,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS { // STAGE 1, QK gemm clear_tile(s_acc); // initialize C - const bool is_sink_tile = ((num_sink_loop - 1) == i_total_loops); - const auto k_move_offset = [&]() { - if constexpr(kHasSink) - return is_sink_tile ? logical_seqlen_k_start - sink_seq_end + kN0 : kN0; - else - return kN0; - }(); + // load the second tile of the first iteration k_block_tile = load_tile(k_dram_window); @@ -522,7 +494,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS #endif } } - move_tile_window(bias_dram_window, {0, k_move_offset}); + move_tile_window(bias_dram_window, {0, kN0}); /// TODO: only check in first/last iteration without increasing code size if constexpr(kHasUnevenSplits) @@ -533,7 +505,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS s_acc, -numeric::infinity(), [&, - physical_seqlen_k_start_ = is_sink_tile ? 0 : physical_seqlen_k_start, + physical_seqlen_k_start_ = physical_seqlen_k_start, physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) { const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); if constexpr(kIsPagedKV) @@ -558,26 +530,12 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS number{}); if(need_perpixel_check) { - auto apply_mask = [&](auto&& mask_func) { - set_tile_if( - s_acc, -numeric::infinity(), [&](auto tile_idx) { - const auto row = - q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = - k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return mask_func(row, col - kv_l2p_offset); - }); - }; - - if constexpr(kHasSink) - { - apply_mask( - [&](auto row, auto col) { return mask.IsOutOfSinkBound(row, col); }); - } - else - { - apply_mask([&](auto row, auto col) { return mask.IsOutOfBound(row, col); }); - } + set_tile_if( + s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col - kv_l2p_offset); + }); } } @@ -588,7 +546,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS { // move K tile windows i_page_block_k = k_page_block_navigator.move_tile_window( - i_page_block_k, k_dram_block_window, {k_move_offset, 0}); + i_page_block_k, k_dram_block_window, {kN0, 0}); k_dram_window = make_tile_window( k_dram_block_window, @@ -784,8 +742,6 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS // moving k_dram_window is an in-page-block operation, so there is // no need to invoke k_page_block_navigator.move_tile_window() here. move_tile_window(k_dram_window, {0, kK0}); - i_page_block_v = v_page_block_navigator.move_tile_window( - i_page_block_v, v_dram_window, {0, k_move_offset - kN0}); store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); } } while(++i_total_loops < num_total_loop); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp index 7eb8872022e..fe5e0bc3452 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -56,7 +56,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kIsPagedKV = Problem::kIsPagedKV; static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits; - static constexpr bool kHasSink = Problem::kHasSink; static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || @@ -230,23 +229,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS set_tile(m, -numeric::infinity()); clear_tile(l); - const auto q_origin = q_dram_window.get_window_origin(); - const auto tile_range_result = [&mask, &q_origin, num_splits, i_split]() { - if constexpr(kHasSink) - return mask.GetSinkTileRangeAlongX( - q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split); - else - { - auto [start, end] = mask.GetTileRangeAlongX( - q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split); - return ck_tile::make_tuple(0, start, end); - } - }(); - const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{}); - const auto logical_seqlen_k_start = tile_range_result.get(ck_tile::number<1>{}); - const auto logical_seqlen_k_end = tile_range_result.get(ck_tile::number<2>{}); - - const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0); + const auto q_origin = q_dram_window.get_window_origin(); + const auto [logical_seqlen_k_start, logical_seqlen_k_end] = mask.GetTileRangeAlongX( + q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split); // check early exit if no work to do if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits) @@ -289,35 +274,24 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS return physical_seqlen_k_start_; } }(); - const auto kv_load_start = (sink_seq_end == 0 && aligned_physical_seqlen_k_start > 0) - ? aligned_physical_seqlen_k_start - : 0; const index_t num_total_loop = - integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0) + - num_sink_loop; + integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0); auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window( - k_dram_block_window_lengths, {kv_load_start, 0}); + k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0}); const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); - - const index_t bias_n_offset = [&]() { - if constexpr(kHasSink) - return kv_load_start; - else - return logical_seqlen_k_start - - (physical_seqlen_k_start - aligned_physical_seqlen_k_start); - }(); - auto bias_dram_window = make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), bias_dram_block_window_tmp.get_window_lengths(), - {bias_origin.at(number<0>{}), bias_n_offset}, + {bias_origin.at(number<0>{}), + logical_seqlen_k_start - (physical_seqlen_k_start - + aligned_physical_seqlen_k_start)}, // M/N Policy::template MakeBiasDramTileDistribution()); auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window( v_dram_block_window_lengths, - {0, kv_load_start}, // TODO: hdim split? + {0, aligned_physical_seqlen_k_start}, // TODO: hdim split? Policy::template MakeVDramTileDistribution()); auto q_tile = tile_elementwise_in(q_element_func, q); @@ -346,18 +320,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); k_block_tile = load_tile(k_dram_window); } - const bool is_sink_tile = ((num_sink_loop - 1) == i_total_loops); - - const auto k_move_offset = [&]() { - if constexpr(kHasSink) - return is_sink_tile ? logical_seqlen_k_start - sink_seq_end + kN0 : kN0; - else - return kN0; - }(); - auto physical_next_block_id_k = amd_wave_read_first_lane(k_page_block_navigator.prefetch_table_id( - i_page_block_k, k_dram_block_window, {k_move_offset, 0})); + i_page_block_k, k_dram_block_window, {kN0, 0})); auto physical_next_block_id_v = amd_wave_read_first_lane( v_page_block_navigator.prefetch_table_id(i_page_block_v, v_dram_window, {0, kK1})); @@ -476,7 +441,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS #endif } } - move_tile_window(bias_dram_window, {0, k_move_offset}); + move_tile_window(bias_dram_window, {0, kN0}); /// TODO: only check in first/last iteration without increasing code size if constexpr(kHasUnevenSplits) @@ -487,7 +452,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS s_acc, -numeric::infinity(), [&, - physical_seqlen_k_start_ = is_sink_tile ? 0 : physical_seqlen_k_start, + physical_seqlen_k_start_ = physical_seqlen_k_start, physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) { const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); if constexpr(kIsPagedKV) @@ -512,26 +477,12 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS number{}); if(need_perpixel_check) { - auto apply_mask = [&](auto&& mask_func) { - set_tile_if( - s_acc, -numeric::infinity(), [&](auto tile_idx) { - const auto row = - q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = - k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return mask_func(row, col - kv_l2p_offset); - }); - }; - - if constexpr(kHasSink) - { - apply_mask( - [&](auto row, auto col) { return mask.IsOutOfSinkBound(row, col); }); - } - else - { - apply_mask([&](auto row, auto col) { return mask.IsOutOfBound(row, col); }); - } + set_tile_if( + s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col - kv_l2p_offset); + }); } } @@ -696,12 +647,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS } // move K tile windows i_page_block_k = k_page_block_navigator.move_tile_window( - i_page_block_k, k_dram_block_window, {k_move_offset, 0}, physical_next_block_id_k); - physical_next_block_id_v = - amd_wave_read_first_lane(v_page_block_navigator.prefetch_table_id( - i_page_block_v, v_dram_window, {0, k_move_offset - kN0})); - i_page_block_v = v_page_block_navigator.move_tile_window( - i_page_block_v, v_dram_window, {0, k_move_offset - kN0}, physical_next_block_id_v); + i_page_block_k, k_dram_block_window, {kN0, 0}, physical_next_block_id_k); // tail { block_sync_lds(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index 48ddb2e3fc8..cc0851efb3a 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -62,7 +62,6 @@ struct BlockFmhaPipelineProblem static constexpr bool kHasDropout = Traits::kHasDropout; static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; - static constexpr bool kHasSink = Traits::kHasSink; }; template {}), number{}, number{}); - const auto tile_range_result = [&mask, &q_origin]() { - if constexpr(kHasSink) - return mask.GetSinkTileRangeAlongX( - q_origin.at(number<0>{}), number{}, number{}); - else - { - auto [start, end] = - mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); - return ck_tile::make_tuple(0, start, end); - } - }(); - const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{}); - const auto seqlen_k_start = tile_range_result.get(ck_tile::number<1>{}); - const auto seqlen_k_end = tile_range_result.get(ck_tile::number<2>{}); - - const auto kv_load_start = (sink_seq_end == 0 && seqlen_k_start > 0) ? seqlen_k_start : 0; - const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0); - const auto num_total_loop = - integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0) + num_sink_loop; + const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); // check early exit if no work to do if constexpr(FmhaMask::IsMasking || kPadSeqLenK) @@ -279,22 +262,22 @@ struct BlockFmhaPipelineQRKSVS auto k_dram_block_window = make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), k_dram_block_window_tmp.get_window_lengths(), - {kv_load_start, 0}); + {seqlen_k_start, 0}); const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); auto bias_dram_window = make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), bias_dram_block_window_tmp.get_window_lengths(), - {bias_origin.at(number<0>{}), kv_load_start}, // M/N + {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N Policy::template MakeBiasDramTileDistribution()); auto randval_dram_window = dropout.template MakeRandvalDramWindow( - randval_dram_block_window_tmp, kv_load_start); + randval_dram_block_window_tmp, seqlen_k_start); auto v_dram_window = make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), v_dram_block_window_tmp.get_window_lengths(), - {0, kv_load_start}, // TODO: hdim split? + {0, seqlen_k_start}, // TODO: hdim split? Policy::template MakeVDramTileDistribution()); auto q_tile = tile_elementwise_in(q_element_func, q); @@ -467,11 +450,6 @@ struct BlockFmhaPipelineQRKSVS #endif } } - if constexpr(kHasSink) - { - if(i_total_loops == 0) - move_tile_window(bias_dram_window, {0, seqlen_k_start - sink_seq_end}); - } move_tile_window(bias_dram_window, {0, kN0}); if constexpr(kPadSeqLenK || FmhaMask::IsMasking) { @@ -482,34 +460,17 @@ struct BlockFmhaPipelineQRKSVS number{}); if(need_perpixel_check) { - auto apply_mask = [&](auto&& mask_func) { - set_tile_if( - s_acc, -numeric::infinity(), [&](auto tile_idx) { - const auto row = - q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = - k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return !mask_func(variant_params, - block_indices.batch_idx, - row, - col, - block_indices.qo_head_idx, - block_indices.kv_head_idx); - }); - }; - - if constexpr(kHasSink) - { - apply_mask([&](auto&&... args) { - return variant.LogitsSinkMask(std::forward(args)...); - }); - } - else - { - apply_mask([&](auto&&... args) { - return variant.LogitsMask(std::forward(args)...); + set_tile_if( + s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return !variant.LogitsMask(variant_params, + block_indices.batch_idx, + row, + col, + block_indices.qo_head_idx, + block_indices.kv_head_idx); }); - } } } @@ -619,23 +580,11 @@ struct BlockFmhaPipelineQRKSVS if constexpr(kHasDropout) { + // K and dropout use the same address in LDS, finish loading from k_lds_window by + // gemm_0 to reuse LDS. block_sync_lds(); - auto randval_ptr = reinterpret_cast(smem_ptr); - - index_t seq_offset = [&]() { - if constexpr(!kHasSink) - return seqlen_k_start + i_total_loops * kN0; - - const bool in_sink_phase = (num_sink_loop > i_total_loops); - if(i_total_loops == num_sink_loop) - move_tile_window(randval_dram_window, {0, seqlen_k_start - sink_seq_end}); - - return in_sink_phase ? (kv_load_start + i_total_loops * kN0) - : (seqlen_k_start + (i_total_loops - num_sink_loop) * kN0); - }(); - dropout.template Run( - randval_ptr, seq_offset, p_compute, randval_dram_window); + smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window); } block_sync_lds(); @@ -687,14 +636,6 @@ struct BlockFmhaPipelineQRKSVS }); } // move K tile windows - if constexpr(kHasSink) - { - if(i_total_loops == 0) - { - move_tile_window(k_dram_block_window, {seqlen_k_start - sink_seq_end, 0}); - move_tile_window(v_dram_window, {0, seqlen_k_start - sink_seq_end}); - } - } move_tile_window(k_dram_block_window, {kN0, 0}); // tail { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 34e347ef2b7..b67c28401f0 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -62,7 +62,6 @@ struct BlockFmhaPipelineQRKSVSAsync static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kHasDropout = Problem::kHasDropout; - static constexpr bool kHasSink = Problem::kHasSink; static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || @@ -278,26 +277,11 @@ struct BlockFmhaPipelineQRKSVSAsync clear_tile(l); __builtin_amdgcn_sched_barrier(0); - const auto q_origin = q_dram_window.get_window_origin(); - const auto tile_range_result = [&mask, &q_origin]() { - if constexpr(kHasSink) - return mask.GetSinkTileRangeAlongX( - q_origin.at(number<0>{}), number{}, number{}); - else - { - auto [start, end] = - mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); - return ck_tile::make_tuple(0, start, end); - } - }(); - const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{}); - const auto seqlen_k_start = tile_range_result.get(ck_tile::number<1>{}); - const auto seqlen_k_end = tile_range_result.get(ck_tile::number<2>{}); + const auto q_origin = q_dram_window.get_window_origin(); + const auto [seqlen_k_start, seqlen_k_end] = + mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); - const auto kv_load_start = (sink_seq_end == 0 && seqlen_k_start > 0) ? seqlen_k_start : 0; - const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0); - const auto num_total_loop = - integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0) + num_sink_loop; + const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); // check early exit if no work to do if constexpr(FmhaMask::IsMasking || kPadSeqLenK) @@ -325,7 +309,7 @@ struct BlockFmhaPipelineQRKSVSAsync auto k_dram_block_window = make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), k_dram_block_window_tmp.get_window_lengths(), - {kv_load_start, 0}); + {seqlen_k_start, 0}); auto k_dram_window = make_tile_window( k_dram_block_window.get_bottom_tensor_view(), @@ -348,16 +332,16 @@ struct BlockFmhaPipelineQRKSVSAsync auto bias_dram_window = make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), bias_dram_block_window_tmp.get_window_lengths(), - {bias_origin.at(number<0>{}), kv_load_start}, // M/N + {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N Policy::template MakeBiasDramTileDistribution()); auto randval_dram_window = dropout.template MakeRandvalDramWindow( - randval_dram_block_window_tmp, kv_load_start); + randval_dram_block_window_tmp, seqlen_k_start); auto v_dram_window = make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), v_dram_block_window_tmp.get_window_lengths(), - {0, kv_load_start}, // TODO: hdim split? + {0, seqlen_k_start}, // TODO: hdim split? Policy::template MakeVDramTileDistribution()); // prefetch K tile @@ -494,11 +478,6 @@ struct BlockFmhaPipelineQRKSVSAsync #endif } } - if constexpr(kHasSink) - { - if(i_total_loops == 0) - move_tile_window(bias_dram_window, {0, seqlen_k_start - sink_seq_end}); - } move_tile_window(bias_dram_window, {0, kN0}); if constexpr(kPadSeqLenK || FmhaMask::IsMasking) { @@ -510,34 +489,17 @@ struct BlockFmhaPipelineQRKSVSAsync if(need_perpixel_check) { - auto apply_mask = [&](auto&& mask_func) { - set_tile_if( - s_acc, -numeric::infinity(), [&](auto tile_idx) { - const auto row = - q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = - k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return !mask_func(variant_params, - block_indices.batch_idx, - row, - col, - block_indices.qo_head_idx, - block_indices.kv_head_idx); - }); - }; - - if constexpr(kHasSink) - { - apply_mask([&](auto&&... args) { - return variant.LogitsSinkMask(std::forward(args)...); - }); - } - else - { - apply_mask([&](auto&&... args) { - return variant.LogitsMask(std::forward(args)...); + set_tile_if( + s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return !variant.LogitsMask(variant_params, + block_indices.batch_idx, + row, + col, + block_indices.qo_head_idx, + block_indices.kv_head_idx); }); - } } } @@ -685,21 +647,11 @@ struct BlockFmhaPipelineQRKSVSAsync { auto randval_ptr = reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); - - index_t seq_offset = [&]() { - if constexpr(!kHasSink) - return seqlen_k_start + i_total_loops * kN0; - - const bool in_sink_phase = (num_sink_loop > i_total_loops); - if(i_total_loops == num_sink_loop) - move_tile_window(randval_dram_window, {0, seqlen_k_start - sink_seq_end}); - - return in_sink_phase ? (kv_load_start + i_total_loops * kN0) - : (seqlen_k_start + (i_total_loops - num_sink_loop) * kN0); - }(); - dropout.template Run( - randval_ptr, seq_offset, p_compute, randval_dram_window); + randval_ptr, + seqlen_k_start + i_total_loops * kN0, + p_compute, + randval_dram_window); } const auto p = [&]() { @@ -765,16 +717,8 @@ struct BlockFmhaPipelineQRKSVSAsync i_total_loops++; if(i_total_loops < num_total_loop) { - if constexpr(kHasSink) - { - if(i_total_loops == 0) - { - move_tile_window(k_dram_block_window, {seqlen_k_start - sink_seq_end, 0}); - move_tile_window(v_dram_window, {0, seqlen_k_start - sink_seq_end}); - } - } + // move K tile windows move_tile_window(k_dram_block_window, {kN0, 0}); - k_dram_window.set_window_origin(k_dram_block_window.get_window_origin()); if constexpr(k1_loops >= 2 && diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp index e837ca49f2a..08fc42a4716 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp @@ -69,7 +69,6 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kHasUnevenSplits = true; - static constexpr bool kHasSink = Problem::kHasSink; static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index 77f216c72dd..59267fa3b17 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -19,9 +19,8 @@ template + index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */ + bool kSkipMinSeqlenQ_ = false /* skip min seqlen q while chunked prefill */> struct TileFmhaTraits { static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; @@ -36,7 +35,6 @@ struct TileFmhaTraits static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; static constexpr index_t kBlockPerCu = kBlockPerCu_; static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; - static constexpr bool kHasSink = kHasSink_; }; template 1 or fwd training is running */ bool kIsPagedKV_, bool kDoFp8StaticQuant_, - index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */ - bool kSkipMinSeqlenQ_ = false, /* skip min seqlen q while chunked prefill */ - bool kHasSink_ = false> + index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */ + bool kSkipMinSeqlenQ_ = false /* skip min seqlen q while chunked prefill */> struct TileFmhaFwdPagedKVTraits { static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; @@ -83,7 +80,6 @@ struct TileFmhaFwdPagedKVTraits static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; static constexpr index_t kBlockPerCu = kBlockPerCu_; static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; - static constexpr bool kHasSink = kHasSink_; }; template + index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */> struct TileFmhaFwdSplitKVTraits { static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; @@ -116,7 +111,6 @@ struct TileFmhaFwdSplitKVTraits static constexpr bool kHasUnevenSplits = kHasUnevenSplits_; static constexpr bool kMergeNumHeadGroupsSeqLenQ = kMergeNumHeadGroupsSeqLenQ_; static constexpr index_t kBlockPerCu = kBlockPerCu_; - static constexpr bool kHasSink = kHasSink_; }; template