diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index ce914b92afb..e35c7f9c369 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -62,7 +62,7 @@ set(FMHA_BWD_CODE_GEN_COMMON_ARGS # there is no corresponding instance for parameters). if(BUILD_TESTING) # Filters are in the order of FMHA_FWD_KNOWN_APIS: fwd,fwd_splitkv_combine@fwd_splitkv,fwd_appendkv,pagedkv_prefill - list(APPEND FMHA_FWD_CODE_GEN_COMMON_ARGS --filter *_nlogits*_nskip*,*@*_nlogits*_nbias*,*,*_nlogits*_nskip*_pagedkv) + list(APPEND FMHA_FWD_CODE_GEN_COMMON_ARGS --filter *_nlogits*_nskip*_nsink*,*@*_nlogits*_nbias*_nsink*,*,*_nlogits*_nskip*_pagedkv*) endif() # generate a list of kernels, but not actually emit files at config sta diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 2acc4674108..6ef77a7c453 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -66,7 +66,8 @@ {F_dropout}, {F_squant}, {F_occupancy}, - {F_skip}>; + {F_skip}, + {F_sink}>; using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; @@ -103,7 +104,7 @@ ck_tile::FmhaFwdKernel; using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip},{F_sink}>; template<> float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) @@ -190,9 +191,9 @@ }} """ -FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && +FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&(t.has_sink == {F_sink}) && ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ - using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; + using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}, {F_sink}>; return fmha_fwd_(s, a); }} """ @@ -239,13 +240,14 @@ class FmhaFwdApiTrait: dvpad: str skip: str tr_load: str + sink: str constraint: CppConstraint @property def name(self) -> str: return ( f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" - + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}" + + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}-{self.sink}" ) @property @@ -345,6 +347,7 @@ class FmhaFwdPipeline: F_mask: str # value from MASK_MAP F_skip: str # true/false F_trload: str # true/false + F_sink: str # true/false F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint()) @property @@ -415,6 +418,10 @@ def pad_name() -> str: n += "_trload" else: n += "_ntrload" + if self.F_sink == "t": + n += "_sink" + else: + n += "_nsink" return n @@ -462,6 +469,7 @@ def api(self) -> str: F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip], F_trload=BOOL_MAP[trait.tr_load], + F_sink=BOOL_MAP[trait.sink], F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_seqtune=trait.seqtune(max_bm0), @@ -588,6 +596,7 @@ def template(self) -> str: F_mode=MODE_MAP[self.F_mode], F_pipeline=PIPELINE_MAP[self.F_pipeline.tag], F_trload=BOOL_MAP[self.F_pipeline.F_trload], + F_sink=BOOL_MAP[self.F_pipeline.F_sink], ) @property @@ -630,6 +639,7 @@ def api_trait(self) -> FmhaFwdApiTrait: dvpad=self.F_pipeline.F_dvpad, skip=self.F_pipeline.F_skip, tr_load=self.F_pipeline.F_trload, + sink=self.F_pipeline.F_sink, constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint, ) @@ -696,49 +706,51 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli pipelines = [] if dtype in ["fp32"]: squant = "f" - for logits, mask, bias, lse, dropout, skip in itertools.product( + for logits, mask, bias, lse, dropout, skip, sink in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], + ["t", "f"], ): - pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "f", "t", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "t", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip elif dtype in ["fp16", "bf16"]: squant = "f" - for logits, mask, bias, lse, dropout, skip in itertools.product( + for logits, mask, bias, lse, dropout, skip, sink in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], + ["t", "f"], ): if hdim == 256 and hdim_v == 256: - pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip # the below two is used for hdim vectorize load - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip else: if bias == "bias": # TODO: rocm 6.2 compiler problem if using qr_async for bias case - pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip else: - pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip if receipt == 1 and bias != "bias": - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip # TODO: cover arbitraty hdim# fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip # TODO: cover arbitraty hdim# fmt: skip elif dtype in ["fp8", "fp8bf16", "fp8fp32"]: # no need lse/dropout kernels for logits, squant, mask, bias in itertools.product( ["f"], ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys() ): - pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f", "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f", "f")) # fmt: skip elif dtype in ["fp8fp16", "bf8"]: # TODO None @@ -757,13 +769,14 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli ) if dtype in ["fp16", "bf16"]: squant = "f" - for logits, mask, bias, lse, dropout, skip in itertools.product( + for logits, mask, bias, lse, dropout, skip, sink in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], + ["t", "f"], ): if ( (hdim, hdim_v) in [(64, 64), (128, 128)] @@ -772,8 +785,8 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli and dropout == "f" and skip == "f" ): - pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "t")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "t")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "t", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "t", sink)) # fmt: skip return pipelines @@ -811,23 +824,24 @@ def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeli pipelines = [] if dtype in ["fp16", "bf16"]: squant = "f" - for logits, mask, bias, lse, dropout, skip in itertools.product( + for logits, mask, bias, lse, dropout, skip, sink in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], + ["t", "f"], ): - pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask, skip, "f", sink)) # fmt: skip elif dtype in ["fp8", "fp8bf16", "fp8fp32"]: # no need lse/dropout kernels for logits, squant, mask, bias in itertools.product( ["f"], ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys() ): - pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f", "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, "f", "f", squant, mask, "f", "f", "f")) # fmt: skip else: assert False return pipelines @@ -934,6 +948,7 @@ def get_fwd_blobs( cond &= pipeline.F_bias in ["no", "alibi"] cond &= pipeline.F_squant == "f" cond &= pipeline.F_skip == "f" + cond &= pipeline.F_sink == "f" if not cond: continue # PyTorch integration @@ -945,6 +960,7 @@ def get_fwd_blobs( cond &= mode == "batch" cond &= pipeline.F_skip == "f" cond &= pipeline.F_logits == "f" + cond &= pipeline.F_sink == "f" if not cond: continue # Aiter(mha_fwd) integration @@ -985,6 +1001,7 @@ def get_fwd_blobs( cond = dtype == "fp32" cond &= pipeline.F_skip == "f" cond &= pipeline.F_logits == "f" + cond &= pipeline.F_sink == "f" if not cond: continue # fp32 only, minimal set of parameters @@ -998,6 +1015,7 @@ def get_fwd_blobs( cond &= pipeline.F_skip == "f" cond &= pipeline.F_logits == "f" cond &= pipeline.F_mask == "s_no" + cond &= pipeline.F_sink == "f" if not cond: continue else: diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 85c25561eab..5029f1fa97c 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -74,7 +74,8 @@ {F_pagedkv}, kHasUnevenSplits, kMergeNumHeadGroupsSeqLenQ, - {F_occupancy}>; + {F_occupancy}, + {F_sink}>; using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem< typename FmhaFwdTypeConfig::QDataType, @@ -118,7 +119,7 @@ }} // anonymous namespace using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_sink}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; #pragma clang diagnostic push @@ -280,8 +281,8 @@ """ FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) && - ((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + ((a.block_table_ptr != nullptr) == {F_pagedkv}) && (t.has_sink == {F_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ + using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv},{F_sink}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; // get combine kernel tile sizes using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType; @@ -333,6 +334,7 @@ class FmhaFwdSplitKVApiTrait: dpad: str dvpad: str pagedkv: str + sink: str # sink or not bn1comb: int # tile size along v head_dim of combine kernel @property @@ -340,7 +342,7 @@ def name(self) -> str: return ( f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-" - + f"{self.dvpad}-{self.pagedkv}" + + f"{self.dvpad}-{self.pagedkv}-{self.sink}" ) @property @@ -426,6 +428,7 @@ class FmhaFwdSplitKVPipeline: F_lse: str # F_squant: str # F_pagedkv: str # t/f + F_sink: str # t/f F_mask: str # value from MASK_MAP @property @@ -486,6 +489,10 @@ def pad_name() -> str: n += "_pagedkv" else: n += "_npagedkv" + if self.F_sink == "t": + n += "_sink" + else: + n += "_nsink" return n @@ -568,6 +575,7 @@ def api(self) -> str: F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv], + F_sink=BOOL_MAP[trait.sink], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, @@ -668,6 +676,7 @@ def template(self) -> str: F_squant=BOOL_MAP[self.F_pipeline.F_squant], F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv], F_occupancy=self.F_tile.F_occupancy, + F_sink=BOOL_MAP[self.F_pipeline.F_sink], F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], F_mode=MODE_MAP[self.F_mode], @@ -741,19 +750,23 @@ def get_pipelines(dtype, hdim, mask_impl) -> List[FmhaFwdSplitKVPipeline]: squant = "t" if dtype == "fp8" else "f" pipelines = [] if dtype in ["fp16", "bf16"]: - for logits, mask, bias, pagedkv in itertools.product( - ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"] + for logits, mask, bias, pagedkv, sink in itertools.product( + ["t", "f"], + get_mask_map(mask_impl).keys(), + BIAS_MAP.keys(), + ["t", "f"], + ["t", "f"], ): - pipelines.append(Pipeline("qr", "row", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip - pipelines.append(Pipeline("qr", "row", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip - pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip - pipelines.append(Pipeline("qr", "row", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip + pipelines.append(Pipeline("qr", "row", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip + pipelines.append(Pipeline("qr", "row", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip + pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip + pipelines.append(Pipeline("qr", "row", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip elif dtype in ["fp8", "bf8"]: for logits, mask, bias in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys() ): - pipelines.append(Pipeline("qr", "row", "f", "f", "f", "f", logits, bias, "t", squant, "f", mask)) # fmt: skip - pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, "f", mask)) # fmt: skip + pipelines.append(Pipeline("qr", "row", "f", "f", "f", "f", logits, bias, "t", squant, "f", "f", mask)) # fmt: skip + pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, "f", "f", mask)) # fmt: skip elif dtype in ["fp8fp16", "fp8bf16"]: # TODO None @@ -909,6 +922,7 @@ def get_fwd_splitkv_blobs( cond &= pipeline.F_vlayout == "row" cond &= pipeline.F_bias in ["no", "alibi"] cond &= pipeline.F_squant == "f" + cond &= pipeline.F_sink == "f" if not cond: continue # PyTorch integration @@ -918,6 +932,7 @@ def get_fwd_splitkv_blobs( cond &= pipeline.F_bias in ["no", "bias"] cond &= pipeline.F_squant == "f" cond &= mode == "batch" + cond &= pipeline.F_sink == "f" if not cond: continue # Aiter(mha_varlen_fwd) integration @@ -1076,6 +1091,7 @@ def write_blobs( lse=kernel.F_pipeline.F_lse, squant=kernel.F_pipeline.F_squant, pagedkv=kernel.F_pipeline.F_pagedkv, + sink=kernel.F_pipeline.F_sink, spad=kernel.F_pipeline.F_spad, skpad=kernel.F_pipeline.F_skpad, dpad=kernel.F_pipeline.F_dpad, diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py index 17ac129e641..3ff47b940a3 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py @@ -66,7 +66,8 @@ {F_pagedkv}, //pagedkv {F_squant}, {F_occupancy}, - {F_skip}>; + {F_skip}, + {F_sink}>; using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; @@ -101,7 +102,7 @@ ck_tile::FmhaFwdPagedKVKernel; using trait_{F_idx} = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>; + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}, {F_sink}>; template<> float fmha_fwd_pagedkv_(const ck_tile::stream_config& s, fmha_fwd_pagedkv_args a) @@ -130,9 +131,9 @@ }} """ -FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && +FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && (t.has_sink == {F_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>; + using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip},{F_sink}>; return fmha_fwd_pagedkv_(s, a); }} """ @@ -164,12 +165,13 @@ class FmhaFwdApiTrait: dpad: str dvpad: str skip: str + sink: str @property def name(self) -> str: return ( f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" - + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.pagedkv}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}" + + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.pagedkv}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}-{self.sink}" ) @property @@ -257,6 +259,7 @@ class FmhaFwdPipeline: F_squant: str # F_mask: str # value from MASK_MAP F_skip: str # true/false + F_sink: str # true/false @property def name(self) -> str: @@ -321,6 +324,10 @@ def pad_name() -> str: n += "_pagedkv" else: n += "_npagedkv" + if self.F_sink == "t": + n += "_sink" + else: + n += "_nsink" return n @@ -364,6 +371,7 @@ def api(self) -> str: F_lse=BOOL_MAP[trait.lse], F_pagedkv=BOOL_MAP[trait.pagedkv], F_skip=BOOL_MAP[trait.skip], + F_sink=BOOL_MAP[trait.sink], F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, @@ -481,6 +489,7 @@ def template(self) -> str: F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv], F_squant=BOOL_MAP[self.F_pipeline.F_squant], F_skip=BOOL_MAP[self.F_pipeline.F_skip], + F_sink=BOOL_MAP[self.F_pipeline.F_sink], F_occupancy=self.F_tile.F_occupancy, F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag], F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], @@ -527,6 +536,7 @@ def api_trait(self) -> FmhaFwdApiTrait: dpad=self.F_pipeline.F_dpad, dvpad=self.F_pipeline.F_dvpad, skip=self.F_pipeline.F_skip, + sink=self.F_pipeline.F_sink, ) @@ -540,22 +550,23 @@ def get_pipelines(dtype, hdim, mask_impl) -> List[FmhaFwdPipeline]: squant = "t" if dtype == "fp8" else "f" pipelines = [] if dtype in ["fp16", "bf16"]: - for logits, mask, bias, pagedkv, skip in itertools.product( + for logits, mask, bias, pagedkv, skip, sink in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t"], ["f"], + ["t", "f"], ): - pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "f", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip)) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "f", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip, sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip, sink)) # fmt: skip elif dtype in ["fp8", "bf8"]: # no need lse/dropout kernels for logits, mask, bias in itertools.product( ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys() ): - pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "f", "f", "f", "f", logits, bias, "f", "t", squant, mask, "f")) # fmt: skip - pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", "t", squant, mask, "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "f", "f", "f", "f", logits, bias, "f", "t", squant, mask, "f", "f")) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", "t", squant, mask, "f", "f")) # fmt: skip elif dtype in ["fp8fp16", "fp8bf16"]: pass # TODO else: @@ -679,6 +690,7 @@ def get_fwd_blobs( cond &= pipeline.F_bias in ["no", "alibi"] cond &= pipeline.F_squant == "f" cond &= pipeline.F_skip == "f" + cond &= pipeline.F_sink == "f" if not cond: continue # PyTorch integration @@ -688,6 +700,7 @@ def get_fwd_blobs( cond &= pipeline.F_bias in ["no", "bias"] cond &= pipeline.F_squant == "f" cond &= pipeline.F_skip == "f" + cond &= pipeline.F_sink == "f" if not cond: continue # Aiter(mha_fwd) integration diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index a952800806e..b95148cbc92 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -265,6 +265,7 @@ struct fmha_fwd_args ck_tile::index_t window_size_left; ck_tile::index_t window_size_right; + ck_tile::index_t sink_size; ck_tile::index_t mask_type; ck_tile::index_t min_seqlen_q; @@ -351,6 +352,7 @@ struct fmha_fwd_pagedkv_args ck_tile::index_t window_size_left; ck_tile::index_t window_size_right; + ck_tile::index_t sink_size; ck_tile::index_t mask_type; ck_tile::index_t min_seqlen_q; }; @@ -441,6 +443,7 @@ struct fmha_fwd_splitkv_args ck_tile::index_t window_size_left; ck_tile::index_t window_size_right; + ck_tile::index_t sink_size; ck_tile::index_t mask_type; }; @@ -611,6 +614,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.nhead_stride_o, args.window_size_left, args.window_size_right, + args.sink_size, args.mask_type, args.min_seqlen_q, args.p_drop, @@ -660,6 +664,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.batch_stride_o, args.window_size_left, args.window_size_right, + args.sink_size, args.mask_type, args.p_drop, args.s_randval, @@ -727,6 +732,7 @@ auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args) args.batch_stride_v, args.window_size_left, args.window_size_right, + args.sink_size, args.mask_type, args.min_seqlen_q); } @@ -772,6 +778,7 @@ auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args) args.batch_stride_o, args.window_size_left, args.window_size_right, + args.sink_size, args.mask_type); } }(); @@ -838,6 +845,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.split_stride_o_acc, args.window_size_left, args.window_size_right, + args.sink_size, args.mask_type); } else @@ -885,6 +893,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.split_stride_o_acc, args.window_size_left, args.window_size_right, + args.sink_size, args.mask_type); } }(); @@ -1131,7 +1140,8 @@ template + bool kSkipMinSeqlenQ_ = false, + bool kHasSink_ = false> struct fmha_fwd_traits_ { static constexpr ck_tile::index_t HDim = HDim_; @@ -1157,6 +1167,7 @@ struct fmha_fwd_traits_ static constexpr bool kPadDv = kPadDv_; static constexpr bool kUseTrLoad = kUseTrLoad_; static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; + static constexpr bool kHasSink = kHasSink_; }; template @@ -1183,7 +1194,8 @@ template + bool kSkipMinSeqlenQ_ = false, + bool kHasSink_ = false> struct fmha_fwd_pagedkv_traits_ { static constexpr ck_tile::index_t HDim = HDim_; @@ -1208,6 +1220,7 @@ struct fmha_fwd_pagedkv_traits_ static constexpr bool kPadD = kPadD_; static constexpr bool kPadDv = kPadDv_; static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; + static constexpr bool kHasSink = kHasSink_; }; template @@ -1230,6 +1243,7 @@ template @@ -1343,6 +1358,7 @@ struct fmha_fwd_traits bool has_dropout; bool do_fp8_static_quant; bool skip_min_seqlen_q = false; + bool has_sink = false; // TODO: padding check is inside this api }; float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&); @@ -1361,6 +1377,7 @@ struct fmha_fwd_pagedkv_traits bool use_pagedkv = true; bool do_fp8_static_quant = false; bool skip_min_seqlen_q = false; + bool has_sink = false; // TODO: padding check is inside this api }; @@ -1380,6 +1397,7 @@ struct fmha_fwd_splitkv_traits bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum bool has_lse; bool do_fp8_static_quant; + bool has_sink = false; // TODO: padding check is inside this api }; float fmha_fwd_splitkv(fmha_fwd_splitkv_traits, diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index 8a663d038d1..a3fc7a9611e 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -907,6 +907,7 @@ fwd_result fmha_fwd_run(mode_enum mode, traits.has_logits_soft_cap = 0.f < logits_soft_cap; traits.mask_type = mask.type; traits.bias_type = bias.type; + traits.has_sink = mask.sink > 0 ? true : false; traits.has_lse = lse; traits.do_fp8_static_quant = squant; @@ -1072,6 +1073,7 @@ fwd_result fmha_fwd_run(mode_enum mode, args.window_size_left = mask.left; args.window_size_right = mask.right; + args.sink_size = mask.sink; args.mask_type = static_cast(mask.type); if constexpr(std::is_same_v>) @@ -1660,7 +1662,7 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::reference_batched_masking( s_host_ref, ck_tile::make_generic_attention_mask_from_lr_window( - mask.left, mask.right, real_seqlen_q, real_seqlen_k)); + mask.left, mask.right, mask.sink, real_seqlen_q, real_seqlen_k)); } else { @@ -1672,6 +1674,7 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::make_generic_attention_mask_from_lr_window( mask.left, mask.right, + mask.sink, real_seqlen_q, real_seqlen_k, mask.type == mask_enum::mask_top_left)); @@ -1681,6 +1684,7 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::make_generic_attention_mask_from_lr_window( mask.left, mask.right, + mask.sink, real_seqlen_q, real_seqlen_k, mask.type == mask_enum::mask_top_left)); diff --git a/example/ck_tile/01_fmha/mask.hpp b/example/ck_tile/01_fmha/mask.hpp index 2dfe0e7c529..aa30db0d6f1 100644 --- a/example/ck_tile/01_fmha/mask.hpp +++ b/example/ck_tile/01_fmha/mask.hpp @@ -25,6 +25,7 @@ struct mask_info ck_tile::index_t seqlen_k; ck_tile::index_t y, x; ck_tile::index_t left, right; // FA style SWA left/right + ck_tile::index_t sink; void serialize(std::ostream& os) const { @@ -58,13 +59,14 @@ struct mask_info ck_tile::index_t window_size = std::stoi(v); ck_tile::index_t left_size = -1; ck_tile::index_t right_size = 0; + ck_tile::index_t sink_size = 0; if(window_size > 0) { left_size = window_size / 2; right_size = window_size - 1 - left_size; } auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( - left_size, right_size, y_total, x_total, t == "xt"); + left_size, right_size, sink_size, y_total, x_total, t == "xt"); tmp.type = t == "xt" ? mask_enum::mask_top_left : mask_enum::mask_bottom_right; tmp.y = r.at(ck_tile::number<0>{}); @@ -79,27 +81,54 @@ struct mask_info { throw std::invalid_argument("invalid mask value: " + str); } - ck_tile::index_t v0 = std::stoi(v.substr(0, found_1)); - ck_tile::index_t v1 = std::stoi(v.substr(found_1 + 1)); + tmp.type = mask_enum::window_generic; + ck_tile::index_t v0 = atoi(v.substr(0, found_1).c_str()); + auto found_2 = v.find(',', found_1 + 1); + ck_tile::index_t v1 = 0; + ck_tile::index_t sink = 0; + // ck_tile::index_t v1 = atoi(v.substr(found_1 + 1).c_str()); + // TODO: some validation if(t == "t") { + if(found_2 != std::string::npos) + { + v1 = atoi(v.substr(found_1 + 1, found_2 - found_1 - 1).c_str()); + sink = atoi(v.substr(found_2 + 1).c_str()); + } + else + { + v1 = atoi(v.substr(found_1 + 1).c_str()); + sink = 0; + } tmp.type = mask_enum::mask_top_left; auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( - v0, v1, y_total, x_total, true); + v0, v1, sink, y_total, x_total, true); tmp.y = r.at(ck_tile::number<0>{}); tmp.x = r.at(ck_tile::number<1>{}); tmp.left = v0; tmp.right = v1; + tmp.sink = sink; } else if(t == "b") { + if(found_2 != std::string::npos) + { + v1 = atoi(v.substr(found_1 + 1, found_2 - found_1 - 1).c_str()); + sink = atoi(v.substr(found_2 + 1).c_str()); + } + else + { + v1 = atoi(v.substr(found_1 + 1).c_str()); + sink = 0; + } tmp.type = mask_enum::mask_bottom_right; auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( - v0, v1, y_total, x_total, false); + v0, v1, sink, y_total, x_total, false); tmp.y = r.at(ck_tile::number<0>{}); tmp.x = r.at(ck_tile::number<1>{}); tmp.left = v0; tmp.right = v1; + tmp.sink = sink; } else if(t == "g") { @@ -108,6 +137,7 @@ struct mask_info tmp.x = v1; tmp.left = v0; // TODO: don't use this? tmp.right = v1; + tmp.sink = 0; } } else @@ -126,6 +156,7 @@ struct mask_info tmp.x = 1; tmp.left = -1; tmp.right = 0; + tmp.sink = 0; } else if(str == "2" || str == "b") { @@ -134,6 +165,7 @@ struct mask_info tmp.x = seqlen_k - seqlen_q + 1; tmp.left = -1; tmp.right = 0; + tmp.sink = 0; } else { diff --git a/example/ck_tile/01_fmha/script/correct_test_fwd_sink.sh b/example/ck_tile/01_fmha/script/correct_test_fwd_sink.sh new file mode 100644 index 00000000000..712db522580 --- /dev/null +++ b/example/ck_tile/01_fmha/script/correct_test_fwd_sink.sh @@ -0,0 +1,74 @@ +#!/bin/bash +# TODO: run this script from CK root or build directory +EXE="$(find . -name tile_example_fmha_fwd -type f | head -n 1)" +KNAME=1 + +export CK_WARMUP=0 +export CK_REPEAT=1 + +COMMON_ARGS='-v=1 -warmup=0 -repeat=1' +# mode=0 +# export HIP_VISIBLE_DEVICES=4 + +TEST_SPLITKV=0 +TEST_APPENDKV=0 +# options: +# -s: run splitkv tests +# -a: run appendkv tests +while getopts ":sa" opt; do + case "${opt}" in + s) + TEST_SPLITKV=1 + ;; + a) + TEST_APPENDKV=1 + ;; + *) + ;; + esac +done + +run_fp16_bf16_tests() { + local NUM_SPLITS="1" + local PAGE_BLOCK_SIZE="0" + local CACHE_BATCH_IDX="0" + + if [ $TEST_SPLITKV -eq 1 ] ; then + NUM_SPLITS="$NUM_SPLITS 2 3" + PAGE_BLOCK_SIZE="$PAGE_BLOCK_SIZE 128" + CACHE_BATCH_IDX="$CACHE_BATCH_IDX 1" + fi + + for prec in "fp16"; do + for mode in 1 0 ; do + for perm in 0 1 ; do + for vlayout in "r" "c" ; do + for batch in 1 4; do + for head in 1; do + for h_k in 1; do + for q_seq in 128 512 ; do + for kv_seq in 128 1024; do + for hdim in 32 64 128 256; do #256 + for lse in 0 1 ; do + for bias in "e" ; do + for p_drop in 0.0 0.2; do # 0.0 + for mask in "t:2,0,4" "b:1,0,2"; do + for num_splits in $NUM_SPLITS ; do + for page_block_size in $PAGE_BLOCK_SIZE ; do + for cache_batch_idx in $CACHE_BATCH_IDX ; do + + # $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=$batch -h=$head -h_k=$h_k -d=16, -d_v=$hdim -s=$q_seq -s_k=$kv_seq -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS -mask=$mask + + done ; done ; done ; done ; done + done ; done ; done ; done ; done + done ; done ; done ; done ; done + done ; done +} + + +set -x + +run_fp16_bf16_tests + +set +x diff --git a/example/ck_tile/01_fmha/script/run_full_test.sh b/example/ck_tile/01_fmha/script/run_full_test.sh index 5c2a5a4b3d0..751fded2de1 100755 --- a/example/ck_tile/01_fmha/script/run_full_test.sh +++ b/example/ck_tile/01_fmha/script/run_full_test.sh @@ -36,6 +36,7 @@ function print_log_header(){ #run verification tests time example/ck_tile/01_fmha/script/smoke_test_fwd.sh time example/ck_tile/01_fmha/script/smoke_test_bwd.sh +time example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh #run performance benchmarks export fmha_fwd_log="perf_fmha_fwd_$GPU_arch.log" diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh new file mode 100755 index 00000000000..b554e16ea7f --- /dev/null +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh @@ -0,0 +1,83 @@ +#!/bin/bash +# TODO: run this script from CK root or build directory +#EXE="/code/composable_kernel/build/bin/tile_example_fmha_fwd" +set -euo pipefail + +SCRIPT_DIR=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd) +EXE_NAME=tile_example_fmha_fwd +EXE="$(find . -name $EXE_NAME -type f | head -n 1)" +KNAME=1 +GPU_arch=$GPU_arch +if [ -z "$GPU_arch" ] ; then + GPU_arch=$(rocminfo | grep -E 'Name:\s+gfx' | head -n1 | awk '{print $2}') +fi +set -x + +COMMON_ARGS='-v=1 -warmup=0 -repeat=1' + + +$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=t:2,0,2 + +# window_size[2,0], sink_size = 2 + +# x=1/y=3 +# 1 * * * * * * * 1 * * * * * * * +# 1 1 * * * * * * 1 1 * * * * * * +# 1 1 1 * * * * * ----> 1 1 1 * * * * * +# * 1 1 1 * * * * 1 1 1 1 * * * * +# * * 1 1 1 * * * 1 1 1 1 1 * * * +# * * * 1 1 1 * * 1 1 * 1 1 1 * * +# * * * * 1 1 1 * 1 1 * * 1 1 1 * +# * * * * * 1 1 1 1 1 * * * 1 1 1 +# l=2/r=0(tl) l=2/r=0/s=2(tl) + +$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=t:0,3,2 #-mask=b:3,0,2 + +# x=4/y=1 +# 1 1 1 1 * * * * 1 1 1 1 * * * * +# * 1 1 1 1 * * * 1 1 1 1 1 * * * +# * * 1 1 1 1 * * ----> 1 1 1 1 1 1 * * +# * * * 1 1 1 1 * 1 1 * 1 1 1 1 * +# * * * * 1 1 1 1 1 1 * * 1 1 1 1 +# l=0/r=3(tl) l=0/r=3/s=2(tl) +# l=3/r=0(br) l=3/r=0/s=2(br) + + +$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:1,0,2 + +# x=4/y=-1 +# * * 1 1 * * * * 1 1 1 1 * * * * +# * * * 1 1 * * * 1 1 * 1 1 * * * +# * * * * 1 1 * * ----> 1 1 * * 1 1 * * +# * * * * * 1 1 * 1 1 * * * 1 1 * +# * * * * * * 1 1 1 1 * * * * 1 1 +# l=1/r=0(br) l=1/r=0/s=2(br) + + +$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:2,0,2 + +# x=-1/y=5 + +# * * * * * * * * * * * * +# * * * * * * * * * * * * +# 1 * * * * * 1 * * * * * +# 1 1 * * * * 1 1 * * * * +# 1 1 1 * * * ----> 1 1 1 * * * +# * 1 1 1 * * 1 1 1 1 * * +# * * 1 1 1 * 1 1 1 1 1 * +# * * * 1 1 1 1 1 * 1 1 1 +# l=2/r=0(br) l=2/r=0/s=2(br) + + +$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:-1,1,2 +# x=-1/y=8 +# * * * * * * * * * * +# * * * * * * * * * * +# 1 * * * * ----> 1 * * * * +# 1 1 * * * 1 1 * * * +# 1 1 1 * * 1 1 1 * * +# 1 1 1 1 * 1 1 1 1 * +# 1 1 1 1 1 1 1 1 1 1 +# 1 1 1 1 1 1 1 1 1 1 +# l=2/r=0(br) l=2/r=0/s=2(br) + \ No newline at end of file diff --git a/include/ck_tile/host/reference/reference_batched_masking.hpp b/include/ck_tile/host/reference/reference_batched_masking.hpp index eece7fc3a81..93329e99ce6 100644 --- a/include/ck_tile/host/reference/reference_batched_masking.hpp +++ b/include/ck_tile/host/reference/reference_batched_masking.hpp @@ -20,7 +20,7 @@ CK_TILE_HOST void reference_batched_masking(HostTensor& c_b_m_n, cons { for(int m = 0; m < M; ++m) { - if(mask.IsOutOfBound(m, n)) + if(mask.IsOutOfSinkBound(m, n)) c_b_m_n(batch, m, n) = -ck_tile::numeric::infinity(); } } diff --git a/include/ck_tile/ops/fmha/block/block_masking.hpp b/include/ck_tile/ops/fmha/block/block_masking.hpp index 2c45945fac0..5484c92f014 100644 --- a/include/ck_tile/ops/fmha/block/block_masking.hpp +++ b/include/ck_tile/ops/fmha/block/block_masking.hpp @@ -86,21 +86,22 @@ struct GenericAttentionMask static constexpr const char* name = impl::MaskName::name; CK_TILE_HOST_DEVICE GenericAttentionMask(index_t y_total_, index_t x_total_) - : GenericAttentionMask(0, 0, y_total_, x_total_) + : GenericAttentionMask(0, 0, 0, y_total_, x_total_) { } CK_TILE_HOST_DEVICE - GenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_) - : y(y_), x(x_), y_total(y_total_), x_total(x_total_) + GenericAttentionMask(index_t y_, index_t x_, index_t sink_, index_t y_total_, index_t x_total_) + : y(y_), x(x_), sink(sink_), y_total(y_total_), x_total(x_total_) { } template CK_TILE_HOST_DEVICE GenericAttentionMask(const MaskCoordinates& mask_coord) : y(mask_coord.at(number<0>{})), x(mask_coord.at(number<1>{})), - y_total(mask_coord.at(number<2>{})), - x_total(mask_coord.at(number<3>{})) + sink(mask_coord.at(number<2>{})), + y_total(mask_coord.at(number<3>{})), + x_total(mask_coord.at(number<4>{})) { } @@ -141,6 +142,44 @@ struct GenericAttentionMask } } + template + CK_TILE_HOST_DEVICE constexpr auto + GetSinkTileRangeAlongX(index_t i_y, number, number) const + { + if constexpr(!IsMasking) + { + return ck_tile::make_tuple(0, 0, x_total); + } + else + { + // get the tile start/end range assum we loop over along X tile by tile + index_t x_start = [&]() { + if constexpr(IsLocal) + { + index_t tmp = max(-y + i_y + 1, 0); + return (tmp / XTile) * XTile; // round to tile aligned + } + else + { + return 0; + } + }(); + + // TODO: end could be negative, we ignore clamp here, and let caller to check + // ... in which case end-start is negative + index_t x_end = [&]() { + index_t tmp = min(i_y + YTile - 1 + x, x_total); + return ((tmp + XTile - 1) / XTile) * XTile; + }(); + + index_t sink_seq_end = sink > 0 ? ((sink + XTile - 1) / XTile) * XTile : 0; + if(x_start <= sink_seq_end && sink > 0) + return ck_tile::make_tuple(0, 0, x_end); + else + return ck_tile::make_tuple(sink_seq_end, x_start, x_end); + } + } + // to get the loop length along Y axis, return index:[start, end), end-start=length // use this if need loop over Y axis tile by tile (like q-seqlen loopover) // TODO: y_end still could be negative, so end-start could be negative(need check) @@ -195,6 +234,30 @@ struct GenericAttentionMask } } + CK_TILE_HOST_DEVICE constexpr auto IsOutOfSinkBound(index_t i_y, index_t i_x) const + { + if constexpr(!IsMasking) + return i_x >= x_total; + // no need to do min/max here, since i_x will never be < 0 or >= x_total + index_t x_start = -y + i_y + 1; + index_t x_end = min(i_y + x, x_total); + + if constexpr(IsLocal) + { + if((i_x < sink) && (y < y_total) && ((i_y + x) > 1) && i_y < x_total) + return false; + else + return i_x < x_start || i_x >= x_end; + } + else + { + if((i_x < sink) && (y < y_total) && ((i_y + x) > 1) && i_y < x_total) + return false; + else + return i_x >= x_end || i_y >= y_total; + } + } + // if current tile is at the edge, means need per-pixel mask check. // otherwise no need to check per-pixel // Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y() @@ -237,7 +300,7 @@ struct GenericAttentionMask } private: - index_t y, x; + index_t y, x, sink; index_t y_total, x_total; }; @@ -260,21 +323,23 @@ struct SimplifiedGenericAttentionMask static constexpr const char* name = impl::SimplifiedMaskName::name; CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(index_t y_total_, index_t x_total_) - : SimplifiedGenericAttentionMask(0, 0, y_total_, x_total_) + : SimplifiedGenericAttentionMask(0, 0, 0, y_total_, x_total_) { } CK_TILE_HOST_DEVICE - SimplifiedGenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_) - : y(y_), x(x_), y_total(y_total_), x_total(x_total_) + SimplifiedGenericAttentionMask( + index_t y_, index_t x_, index_t sink_, index_t y_total_, index_t x_total_) + : y(y_), x(x_), sink(sink_), y_total(y_total_), x_total(x_total_) { } template CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(const MaskCoordinates& mask_coord) : y(mask_coord.at(number<0>{})), x(mask_coord.at(number<1>{})), - y_total(mask_coord.at(number<2>{})), - x_total(mask_coord.at(number<3>{})) + sink(mask_coord.at(number<2>{})), + y_total(mask_coord.at(number<3>{})), + x_total(mask_coord.at(number<4>{})) { } @@ -308,6 +373,38 @@ struct SimplifiedGenericAttentionMask } } + template + CK_TILE_HOST_DEVICE constexpr auto + GetSinkTileRangeAlongX(index_t i_y, number, number) const + { + if constexpr(!IsMasking) + { + return ck_tile::make_tuple(0, 0, x_total); + } + else + { + // get the tile start/end range assum we loop over along X tile by tile + index_t x_start = [&]() { + index_t tmp = max(-y + i_y + 1, 0); + return (tmp / XTile) * XTile; // round to tile aligned + }(); + + // TODO: end could be negative, we ignore clamp here, and let caller to check + // ... in which case end-start is negative + index_t x_end = [&]() { + index_t tmp = min(i_y + YTile - 1 + x, x_total); + return ((tmp + XTile - 1) / XTile) * XTile; + }(); + + index_t sink_seq_end = sink > 0 ? ((sink + XTile - 1) / XTile) * XTile : 0; + + if(x_start <= sink_seq_end && sink > 0) + return ck_tile::make_tuple(0, 0, x_end); + else + return ck_tile::make_tuple(sink_seq_end, x_start, x_end); + } + } + template CK_TILE_HOST_DEVICE constexpr auto GetTileRangeAlongX(index_t i_y, number height, @@ -325,6 +422,29 @@ struct SimplifiedGenericAttentionMask ck_tile::min(origin_end, split_end)); } + template + CK_TILE_HOST_DEVICE constexpr auto GetSinkTileRangeAlongX(index_t i_y, + number height, + number width, + index_t num_splits, + index_t i_split) const + { + auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width); + const index_t x_per_split = ck_tile::max(1, integer_divide_ceil(x_total, num_splits)); + const index_t split_start = x_per_split * i_split; // 128 + const index_t split_end = ck_tile::min(x_total, split_start + x_per_split); // 256 + const index_t sink_seq_end = sink > 0 ? ((sink + width - 1) / width) * width : 0; + const index_t start = ck_tile::max(origin_start, split_start); + const index_t end = ck_tile::min(origin_end, split_end); + const bool is_first_intersecting_split = + (split_start <= origin_start && split_end >= origin_start); + const bool sink_in_range = (sink_seq_end <= start); + + const index_t sink_offset = + (is_first_intersecting_split && sink_in_range) ? sink_seq_end : 0; + return ck_tile::make_tuple(sink_offset, start, end); + } + // to get the loop length along Y axis, return index:[start, end), end-start=length // use this if need loop over Y axis tile by tile (like q-seqlen loopover) // TODO: y_end still could be negative, so end-start could be negative(need check) @@ -368,11 +488,22 @@ struct SimplifiedGenericAttentionMask { index_t x_start = -y + i_y + 1; // this could be negative, but it's fine index_t x_end = min(i_y + x, x_total); // need min in case x is padded - return i_x < x_start || i_x >= x_end || i_y >= y_total; } } + CK_TILE_HOST_DEVICE constexpr auto IsOutOfSinkBound(index_t i_y, index_t i_x) const + { + if constexpr(!IsMasking) + return i_x >= x_total; + index_t x_start = -y + i_y + 1; // this could be negative, but it's fine + index_t x_end = min(i_y + x, x_total); // need min in case x is padded + if((i_x < sink) && (y < y_total) && ((i_y + x) > 1) && i_y < x_total) + return false; + else + return i_x < x_start || i_x >= x_end || i_y >= y_total; + } + // if current tile is at the edge, means need per-pixel mask check. // otherwise no need to check per-pixel // Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y() @@ -406,7 +537,7 @@ struct SimplifiedGenericAttentionMask } private: - index_t y, x; + index_t y, x, sink; index_t y_total, x_total; }; @@ -607,6 +738,7 @@ struct SimplifiedRatioAttentionMask CK_TILE_HOST_DEVICE constexpr auto make_generic_attention_mask_coordinates_from_lr_window(index_t left_size, index_t right_size, + index_t sink_size, index_t y_total, index_t x_total, bool is_top_left = true) @@ -624,7 +756,21 @@ make_generic_attention_mask_coordinates_from_lr_window(index_t left_size, index_t x = 1 + right_size + x_tmp; index_t y = 1 + left_size + y_tmp; - return ck_tile::make_tuple(y, x, y_total, x_total); + return ck_tile::make_tuple(y, x, sink_size, y_total, x_total); +} + +template +CK_TILE_HOST_DEVICE constexpr auto +make_generic_attention_mask_from_lr_window(index_t left_size, + index_t right_size, + index_t sink_size, + index_t y_total, + index_t x_total, + bool is_top_left = true) +{ + auto r = make_generic_attention_mask_coordinates_from_lr_window( + left_size, right_size, sink_size, y_total, x_total, is_top_left); + return MaskType{r.at(number<0>{}), r.at(number<1>{}), sink_size, y_total, x_total}; } template @@ -636,7 +782,7 @@ make_generic_attention_mask_from_lr_window(index_t left_size, bool is_top_left = true) { auto r = make_generic_attention_mask_coordinates_from_lr_window( - left_size, right_size, y_total, x_total, is_top_left); - return MaskType{r.at(number<0>{}), r.at(number<1>{}), y_total, x_total}; + left_size, right_size, 0, y_total, x_total, is_top_left); + return MaskType{r.at(number<0>{}), r.at(number<1>{}), 0, y_total, x_total}; } } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/block/variants.hpp b/include/ck_tile/ops/fmha/block/variants.hpp index d8b0cdbb86b..245f5dc5682 100644 --- a/include/ck_tile/ops/fmha/block/variants.hpp +++ b/include/ck_tile/ops/fmha/block/variants.hpp @@ -162,6 +162,17 @@ struct StandardAttention { return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx); } + + template + __device__ __forceinline__ bool LogitsSinkMask(const Params& params, + [[maybe_unused]] uint32_t batch_idx, + uint32_t qo_idx, + uint32_t kv_idx, + [[maybe_unused]] uint32_t qo_head_idx, + [[maybe_unused]] uint32_t kv_head_idx) const + { + return !params.impl_mask.IsOutOfSinkBound(qo_idx, kv_idx); + } }; template @@ -224,6 +235,17 @@ struct LogitsSoftCap { return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx); } + + template + __device__ __forceinline__ bool LogitsSinkMask(const Params& params, + [[maybe_unused]] uint32_t batch_idx, + uint32_t qo_idx, + uint32_t kv_idx, + [[maybe_unused]] uint32_t qo_head_idx, + [[maybe_unused]] uint32_t kv_head_idx) const + { + return !params.impl_mask.IsOutOfSinkBound(qo_idx, kv_idx); + } }; constexpr uint32_t CUSTOM_MASK = 1U; @@ -297,6 +319,17 @@ struct ComposedAttention { return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx); } + + template + __device__ __forceinline__ bool LogitsSinkMask(const Params& params, + [[maybe_unused]] uint32_t batch_idx, + uint32_t qo_idx, + uint32_t kv_idx, + [[maybe_unused]] uint32_t qo_head_idx, + [[maybe_unused]] uint32_t kv_head_idx) const + { + return !params.impl_mask.IsOutOfSinkBound(qo_idx, kv_idx); + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp index 3b476299e15..cd5b180a39d 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp @@ -198,7 +198,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel struct FmhaFwdMaskKargs { // ck_tile::index_t window_size_left, window_size_right; - ck_tile::index_t window_size_left, window_size_right; + ck_tile::index_t window_size_left, window_size_right, sink_size; ck_tile::GenericAttentionMaskEnum mask_type; }; @@ -362,6 +362,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, + ck_tile::index_t sink_size, ck_tile::index_t mask_type, float p_drop, bool s_randval, @@ -425,6 +426,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel { kargs.window_size_left = window_size_left; kargs.window_size_right = window_size_right; + kargs.sink_size = sink_size; kargs.mask_type = static_cast(mask_type); } if constexpr(kStoreLSE) @@ -509,6 +511,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel ck_tile::index_t batch_stride_v, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, + ck_tile::index_t sink_size, ck_tile::index_t mask_type, float p_drop, bool s_randval, @@ -570,6 +573,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel { kargs.window_size_left = window_size_left; kargs.window_size_right = window_size_right; + kargs.sink_size = sink_size; kargs.mask_type = static_cast(mask_type); } if constexpr(kStoreLSE) @@ -1026,6 +1030,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel return ck_tile::make_generic_attention_mask_from_lr_window( kargs.window_size_left, kargs.window_size_right, + kargs.sink_size, kargs.seqlen_q, kargs.seqlen_k, kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index fba3065842a..414eae3651e 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -56,6 +56,7 @@ struct FmhaFwdKernel static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ; + static constexpr bool kHasSink = FmhaPipeline::kHasSink; using AttentionVariant = ck_tile::remove_cvref_t; using FmhaMask = ck_tile::remove_cvref_t; @@ -112,7 +113,7 @@ struct FmhaFwdKernel (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) + (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + - (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kUseTrLoad ? "_trload" : "_ntrload"); + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kUseTrLoad ? "_trload" : "_ntrload") + (kHasSink ? "_sink" : "_nsink"); #undef _SS_ #undef _TS_ // clang-format on @@ -200,7 +201,7 @@ struct FmhaFwdKernel struct FmhaFwdMaskKargs { // ck_tile::index_t window_size_left, window_size_right; - ck_tile::index_t window_size_left, window_size_right; + ck_tile::index_t window_size_left, window_size_right, sink_size; ck_tile::GenericAttentionMaskEnum mask_type; }; @@ -374,6 +375,7 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, + ck_tile::index_t sink_size, ck_tile::index_t mask_type, float p_drop, bool s_randval, @@ -432,6 +434,7 @@ struct FmhaFwdKernel { kargs.window_size_left = window_size_left; kargs.window_size_right = window_size_right; + kargs.sink_size = sink_size; kargs.mask_type = static_cast(mask_type); } if constexpr(kStoreLSE) @@ -518,6 +521,7 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, + ck_tile::index_t sink_size, ck_tile::index_t mask_type, float p_drop, bool s_randval, @@ -565,6 +569,7 @@ struct FmhaFwdKernel batch_stride_o, window_size_left, window_size_right, + sink_size, mask_type, p_drop, s_randval, @@ -615,6 +620,7 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, + ck_tile::index_t sink_size, ck_tile::index_t mask_type, float p_drop, bool s_randval, @@ -662,6 +668,7 @@ struct FmhaFwdKernel batch_stride_o, window_size_left, window_size_right, + sink_size, mask_type, p_drop, s_randval, @@ -706,6 +713,7 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, + ck_tile::index_t sink_size, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q, float p_drop, @@ -765,6 +773,7 @@ struct FmhaFwdKernel { kargs.window_size_left = window_size_left; kargs.window_size_right = window_size_right; + kargs.sink_size = sink_size; kargs.mask_type = static_cast(mask_type); } if constexpr(kStoreLSE) @@ -848,6 +857,7 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, + ck_tile::index_t sink_size, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q, float p_drop, @@ -891,6 +901,7 @@ struct FmhaFwdKernel nhead_stride_o, window_size_left, window_size_right, + sink_size, mask_type, min_seqlen_q, p_drop, @@ -937,6 +948,7 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, + ck_tile::index_t sink_size, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q, float p_drop, @@ -980,6 +992,7 @@ struct FmhaFwdKernel nhead_stride_o, window_size_left, window_size_right, + sink_size, mask_type, min_seqlen_q, p_drop, @@ -1471,6 +1484,7 @@ struct FmhaFwdKernel return ck_tile::make_generic_attention_mask_from_lr_window( kargs.window_size_left, kargs.window_size_right, + kargs.sink_size, kargs.seqlen_q, kargs.seqlen_k, kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); @@ -2200,6 +2214,7 @@ struct FmhaFwdKernel return ck_tile::make_generic_attention_mask_from_lr_window( kargs.window_size_left, kargs.window_size_right, + kargs.sink_size, kargs.seqlen_q, kargs.seqlen_k, kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp index a2e6f083616..712892387cc 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp @@ -55,6 +55,7 @@ struct FmhaFwdPagedKVKernel static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ; static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV; + static constexpr bool kHasSink = FmhaPipeline::kHasSink; using AttentionVariant = ck_tile::remove_cvref_t; using FmhaMask = ck_tile::remove_cvref_t; @@ -101,7 +102,7 @@ struct FmhaFwdPagedKVKernel (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) + (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + - (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ); + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ) + (kHasSink ? "_sink" : "_nsink" ); #undef _SS_ #undef _TS_ // clang-format on @@ -189,7 +190,7 @@ struct FmhaFwdPagedKVKernel struct FmhaFwdMaskKargs { // ck_tile::index_t window_size_left, window_size_right; - ck_tile::index_t window_size_left, window_size_right; + ck_tile::index_t window_size_left, window_size_right, sink_size; ck_tile::GenericAttentionMaskEnum mask_type; }; @@ -326,6 +327,7 @@ struct FmhaFwdPagedKVKernel ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, + ck_tile::index_t sink_size, ck_tile::index_t mask_type) { Kargs kargs{{q_ptr, @@ -379,6 +381,7 @@ struct FmhaFwdPagedKVKernel { kargs.window_size_left = window_size_left; kargs.window_size_right = window_size_right; + kargs.sink_size = sink_size; kargs.mask_type = static_cast(mask_type); } if constexpr(kStoreLSE) @@ -453,6 +456,7 @@ struct FmhaFwdPagedKVKernel ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, + ck_tile::index_t sink_size, ck_tile::index_t mask_type) { return MakeKargsImpl(q_ptr, @@ -495,6 +499,7 @@ struct FmhaFwdPagedKVKernel batch_stride_o, window_size_left, window_size_right, + sink_size, mask_type); } @@ -536,6 +541,7 @@ struct FmhaFwdPagedKVKernel ck_tile::index_t batch_stride_v, // only used for paged-kvcache ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, + ck_tile::index_t sink_size, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q) { @@ -590,6 +596,7 @@ struct FmhaFwdPagedKVKernel { kargs.window_size_left = window_size_left; kargs.window_size_right = window_size_right; + kargs.sink_size = sink_size; kargs.mask_type = static_cast(mask_type); } if constexpr(kStoreLSE) @@ -660,6 +667,7 @@ struct FmhaFwdPagedKVKernel ck_tile::index_t batch_stride_v, // only used for paged-kvcache ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, + ck_tile::index_t sink_size, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q) { @@ -699,6 +707,7 @@ struct FmhaFwdPagedKVKernel batch_stride_v, window_size_left, window_size_right, + sink_size, mask_type, min_seqlen_q); } @@ -1257,6 +1266,7 @@ struct FmhaFwdPagedKVKernel return ck_tile::make_generic_attention_mask_from_lr_window( kargs.window_size_left, kargs.window_size_right, + kargs.sink_size, kargs.seqlen_q, kargs.seqlen_k, kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index a6e44c7293c..1b5deb26e0c 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -51,6 +51,7 @@ struct FmhaFwdSplitKVKernel static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV; + static constexpr bool kHasSink = FmhaPipeline::Problem::kHasSink; static constexpr bool kMergeNumHeadGroupsSeqLenQ = FmhaPipeline::Problem::kMergeNumHeadGroupsSeqLenQ; using AttentionVariant = ck_tile::remove_cvref_t; @@ -101,7 +102,7 @@ struct FmhaFwdSplitKVKernel "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) + (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + - (kDoFp8StaticQuant ? "_squant" : "_nsquant") + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ); + (kDoFp8StaticQuant ? "_squant" : "_nsquant") + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ) + (kHasSink ? "_sink" : "_nsink" ); #undef _SS_ #undef _TS_ // clang-format on @@ -198,7 +199,7 @@ struct FmhaFwdSplitKVKernel struct MaskKargs { // ck_tile::index_t window_size_left, window_size_right; - ck_tile::index_t window_size_left, window_size_right; + ck_tile::index_t window_size_left, window_size_right, sink_size; ck_tile::GenericAttentionMaskEnum mask_type; }; @@ -325,6 +326,7 @@ struct FmhaFwdSplitKVKernel ck_tile::index_t split_stride_o_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, + ck_tile::index_t sink_size, ck_tile::index_t mask_type) { Kargs kargs{{q_ptr, @@ -384,6 +386,7 @@ struct FmhaFwdSplitKVKernel { kargs.window_size_left = window_size_left; kargs.window_size_right = window_size_right; + kargs.sink_size = sink_size; kargs.mask_type = static_cast(mask_type); } if constexpr(kDoFp8StaticQuant) @@ -451,6 +454,7 @@ struct FmhaFwdSplitKVKernel ck_tile::index_t split_stride_o_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, + ck_tile::index_t sink_size, ck_tile::index_t mask_type) { Kargs kargs{{q_ptr, @@ -508,6 +512,7 @@ struct FmhaFwdSplitKVKernel { kargs.window_size_left = window_size_left; kargs.window_size_right = window_size_right; + kargs.sink_size = sink_size; kargs.mask_type = static_cast(mask_type); } if constexpr(kDoFp8StaticQuant) @@ -994,6 +999,7 @@ struct FmhaFwdSplitKVKernel return ck_tile::make_generic_attention_mask_from_lr_window( kargs.window_size_left, kargs.window_size_right, + kargs.sink_size, kargs.seqlen_q, kargs.seqlen_k, kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp index 7a8e9a1d470..d73c673a58a 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp @@ -57,6 +57,7 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kIsPagedKV = Problem::kIsPagedKV; + static constexpr bool kHasSink = Problem::kHasSink; static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || @@ -228,10 +229,22 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS clear_tile(o_acc); set_tile(m, -numeric::infinity()); clear_tile(l); - - const auto q_origin = q_dram_window.get_window_origin(); - const auto [logical_seqlen_k_start, logical_seqlen_k_end] = - mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); + const auto q_origin = q_dram_window.get_window_origin(); + const auto tile_range_result = [&mask, &q_origin]() { + if constexpr(kHasSink) + return mask.GetSinkTileRangeAlongX( + q_origin.at(number<0>{}), number{}, number{}); + else + { + auto [start, end] = + mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); + return ck_tile::make_tuple(0, start, end); + } + }(); + const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{}); + const auto logical_seqlen_k_start = tile_range_result.get(ck_tile::number<1>{}); + const auto logical_seqlen_k_end = tile_range_result.get(ck_tile::number<2>{}); + const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0); // check early exit if no work to do if constexpr(FmhaMask::IsMasking || kPadSeqLenK) @@ -255,7 +268,6 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS return o_acc; } } - // k_dram_block_window const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset; const index_t physical_seqlen_k_end = logical_seqlen_k_end + kv_l2p_offset; @@ -274,27 +286,36 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS return physical_seqlen_k_start_; } }(); + const auto kv_load_start = (sink_seq_end == 0 && aligned_physical_seqlen_k_start > 0) + ? aligned_physical_seqlen_k_start + : 0; const index_t num_total_loop = - integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0); + integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0) + + num_sink_loop; auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window( - k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0}); + k_dram_block_window_lengths, {kv_load_start, 0}); + + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + const index_t bias_n_offset = [&]() { + if constexpr(kHasSink) + return kv_load_start; + else + return logical_seqlen_k_start - + (physical_seqlen_k_start - aligned_physical_seqlen_k_start); + }(); - const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); auto bias_dram_window = make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), bias_dram_block_window_tmp.get_window_lengths(), - {bias_origin.at(number<0>{}), - logical_seqlen_k_start - (physical_seqlen_k_start - - aligned_physical_seqlen_k_start)}, // M/N + {bias_origin.at(number<0>{}), bias_n_offset}, Policy::template MakeBiasDramTileDistribution()); // v_dram_window auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window( v_dram_block_window_lengths, - {0, aligned_physical_seqlen_k_start}, // TODO: hdim split? + {0, kv_load_start}, // TODO: hdim split? Policy::template MakeVDramTileDistribution()); - auto q_tile = tile_elementwise_in(q_element_func, q); // prefetch K tile @@ -321,9 +342,16 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); k_block_tile = load_tile(k_dram_window); } + const bool is_sink_tile = ((num_sink_loop - 1) == i_total_loops); + const auto k_move_offset = [&]() { + if constexpr(kHasSink) + return is_sink_tile ? logical_seqlen_k_start - sink_seq_end + kN0 : kN0; + else + return kN0; + }(); auto physical_next_block_id_k = amd_wave_read_first_lane(k_page_block_navigator.prefetch_table_id( - i_page_block_k, k_dram_block_window, {kN0, 0})); + i_page_block_k, k_dram_block_window, {k_move_offset, 0})); auto physical_next_block_id_v = amd_wave_read_first_lane( v_page_block_navigator.prefetch_table_id(i_page_block_v, v_dram_window, {0, kK1})); @@ -442,7 +470,7 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS #endif } } - move_tile_window(bias_dram_window, {0, kN0}); + move_tile_window(bias_dram_window, {0, k_move_offset}); { const auto k_origin = k_page_block_navigator.to_global_window_origin( @@ -474,14 +502,29 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS number{}); if(need_perpixel_check) { - set_tile_if( - s_acc, -numeric::infinity(), [&](auto tile_idx) { - const auto row = - q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = - k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return mask.IsOutOfBound(row, col - kv_l2p_offset); + auto apply_mask = [&](auto&& mask_func) { + set_tile_if(s_acc, + -numeric::infinity(), + [&](auto tile_idx) { + const auto row = + q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = + k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask_func(row, col - kv_l2p_offset); + }); + }; + + if constexpr(kHasSink) + { + apply_mask([&](auto row, auto col) { + return mask.IsOutOfSinkBound(row, col); }); + } + else + { + apply_mask( + [&](auto row, auto col) { return mask.IsOutOfBound(row, col); }); + } } } } @@ -647,7 +690,12 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS } // move K tile windows i_page_block_k = k_page_block_navigator.move_tile_window( - i_page_block_k, k_dram_block_window, {kN0, 0}, physical_next_block_id_k); + i_page_block_k, k_dram_block_window, {k_move_offset, 0}, physical_next_block_id_k); + physical_next_block_id_v = + amd_wave_read_first_lane(v_page_block_navigator.prefetch_table_id( + i_page_block_v, v_dram_window, {0, k_move_offset - kN0})); + i_page_block_v = v_page_block_navigator.move_tile_window( + i_page_block_v, v_dram_window, {0, k_move_offset - kN0}, physical_next_block_id_v); // tail { block_sync_lds(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp index 4d1c38e0790..65e0eb7dfb3 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp @@ -57,6 +57,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kIsPagedKV = Problem::kIsPagedKV; static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits; + static constexpr bool kHasSink = Problem::kHasSink; static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || @@ -256,11 +257,23 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS set_tile(m, -numeric::infinity()); clear_tile(l); - const auto q_origin = q_dram_window.get_window_origin(); - const auto [logical_seqlen_k_start, logical_seqlen_k_end] = mask.GetTileRangeAlongX( - q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split); + const auto q_origin = q_dram_window.get_window_origin(); + const auto tile_range_result = [&mask, &q_origin, num_splits, i_split]() { + if constexpr(kHasSink) + return mask.GetSinkTileRangeAlongX( + q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split); + else + { + auto [start, end] = mask.GetTileRangeAlongX( + q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split); + return ck_tile::make_tuple(0, start, end); + } + }(); + const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{}); + const auto logical_seqlen_k_start = tile_range_result.get(ck_tile::number<1>{}); + const auto logical_seqlen_k_end = tile_range_result.get(ck_tile::number<2>{}); - // check early exit if no work to do + const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0); if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits) { const index_t logical_num_total_loop = @@ -304,24 +317,33 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS return physical_seqlen_k_start_; } }(); + const auto kv_load_start = (sink_seq_end == 0 && aligned_physical_seqlen_k_start > 0) + ? aligned_physical_seqlen_k_start + : 0; const index_t num_total_loop = - integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0); + integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0) + + num_sink_loop; auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window( - k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0}); + k_dram_block_window_lengths, {kv_load_start, 0}); - const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + const index_t bias_n_offset = [&]() { + if constexpr(kHasSink) + return kv_load_start; + else + return logical_seqlen_k_start - + (physical_seqlen_k_start - aligned_physical_seqlen_k_start); + }(); auto bias_dram_window = make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), bias_dram_block_window_tmp.get_window_lengths(), - {bias_origin.at(number<0>{}), - logical_seqlen_k_start - (physical_seqlen_k_start - - aligned_physical_seqlen_k_start)}, // M/N + {bias_origin.at(number<0>{}), bias_n_offset}, Policy::template MakeBiasDramTileDistribution()); auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window( v_dram_block_window_lengths, - {0, aligned_physical_seqlen_k_start}, // TODO: hdim split? + {0, kv_load_start}, // TODO: hdim split? Policy::template MakeVDramTileDistribution()); // store Q into LDS @@ -369,7 +391,13 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS { // STAGE 1, QK gemm clear_tile(s_acc); // initialize C - + const bool is_sink_tile = ((num_sink_loop - 1) == i_total_loops); + const auto k_move_offset = [&]() { + if constexpr(kHasSink) + return is_sink_tile ? logical_seqlen_k_start - sink_seq_end + kN0 : kN0; + else + return kN0; + }(); // load the second tile of the first iteration k_block_tile = load_tile(k_dram_window); @@ -494,7 +522,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS #endif } } - move_tile_window(bias_dram_window, {0, kN0}); + move_tile_window(bias_dram_window, {0, k_move_offset}); /// TODO: only check in first/last iteration without increasing code size if constexpr(kHasUnevenSplits) @@ -505,7 +533,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS s_acc, -numeric::infinity(), [&, - physical_seqlen_k_start_ = physical_seqlen_k_start, + physical_seqlen_k_start_ = is_sink_tile ? 0 : physical_seqlen_k_start, physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) { const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); if constexpr(kIsPagedKV) @@ -530,12 +558,26 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS number{}); if(need_perpixel_check) { - set_tile_if( - s_acc, -numeric::infinity(), [&](auto tile_idx) { - const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return mask.IsOutOfBound(row, col - kv_l2p_offset); - }); + auto apply_mask = [&](auto&& mask_func) { + set_tile_if( + s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = + q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = + k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask_func(row, col - kv_l2p_offset); + }); + }; + + if constexpr(kHasSink) + { + apply_mask( + [&](auto row, auto col) { return mask.IsOutOfSinkBound(row, col); }); + } + else + { + apply_mask([&](auto row, auto col) { return mask.IsOutOfBound(row, col); }); + } } } @@ -546,7 +588,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS { // move K tile windows i_page_block_k = k_page_block_navigator.move_tile_window( - i_page_block_k, k_dram_block_window, {kN0, 0}); + i_page_block_k, k_dram_block_window, {k_move_offset, 0}); k_dram_window = make_tile_window( k_dram_block_window, @@ -742,6 +784,8 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS // moving k_dram_window is an in-page-block operation, so there is // no need to invoke k_page_block_navigator.move_tile_window() here. move_tile_window(k_dram_window, {0, kK0}); + i_page_block_v = v_page_block_navigator.move_tile_window( + i_page_block_v, v_dram_window, {0, k_move_offset - kN0}); store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); } } while(++i_total_loops < num_total_loop); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp index fe5e0bc3452..7eb8872022e 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -56,6 +56,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kIsPagedKV = Problem::kIsPagedKV; static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits; + static constexpr bool kHasSink = Problem::kHasSink; static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || @@ -229,9 +230,23 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS set_tile(m, -numeric::infinity()); clear_tile(l); - const auto q_origin = q_dram_window.get_window_origin(); - const auto [logical_seqlen_k_start, logical_seqlen_k_end] = mask.GetTileRangeAlongX( - q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split); + const auto q_origin = q_dram_window.get_window_origin(); + const auto tile_range_result = [&mask, &q_origin, num_splits, i_split]() { + if constexpr(kHasSink) + return mask.GetSinkTileRangeAlongX( + q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split); + else + { + auto [start, end] = mask.GetTileRangeAlongX( + q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split); + return ck_tile::make_tuple(0, start, end); + } + }(); + const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{}); + const auto logical_seqlen_k_start = tile_range_result.get(ck_tile::number<1>{}); + const auto logical_seqlen_k_end = tile_range_result.get(ck_tile::number<2>{}); + + const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0); // check early exit if no work to do if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits) @@ -274,24 +289,35 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS return physical_seqlen_k_start_; } }(); + const auto kv_load_start = (sink_seq_end == 0 && aligned_physical_seqlen_k_start > 0) + ? aligned_physical_seqlen_k_start + : 0; const index_t num_total_loop = - integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0); + integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0) + + num_sink_loop; auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window( - k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0}); + k_dram_block_window_lengths, {kv_load_start, 0}); const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + + const index_t bias_n_offset = [&]() { + if constexpr(kHasSink) + return kv_load_start; + else + return logical_seqlen_k_start - + (physical_seqlen_k_start - aligned_physical_seqlen_k_start); + }(); + auto bias_dram_window = make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), bias_dram_block_window_tmp.get_window_lengths(), - {bias_origin.at(number<0>{}), - logical_seqlen_k_start - (physical_seqlen_k_start - - aligned_physical_seqlen_k_start)}, // M/N + {bias_origin.at(number<0>{}), bias_n_offset}, Policy::template MakeBiasDramTileDistribution()); auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window( v_dram_block_window_lengths, - {0, aligned_physical_seqlen_k_start}, // TODO: hdim split? + {0, kv_load_start}, // TODO: hdim split? Policy::template MakeVDramTileDistribution()); auto q_tile = tile_elementwise_in(q_element_func, q); @@ -320,9 +346,18 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); k_block_tile = load_tile(k_dram_window); } + const bool is_sink_tile = ((num_sink_loop - 1) == i_total_loops); + + const auto k_move_offset = [&]() { + if constexpr(kHasSink) + return is_sink_tile ? logical_seqlen_k_start - sink_seq_end + kN0 : kN0; + else + return kN0; + }(); + auto physical_next_block_id_k = amd_wave_read_first_lane(k_page_block_navigator.prefetch_table_id( - i_page_block_k, k_dram_block_window, {kN0, 0})); + i_page_block_k, k_dram_block_window, {k_move_offset, 0})); auto physical_next_block_id_v = amd_wave_read_first_lane( v_page_block_navigator.prefetch_table_id(i_page_block_v, v_dram_window, {0, kK1})); @@ -441,7 +476,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS #endif } } - move_tile_window(bias_dram_window, {0, kN0}); + move_tile_window(bias_dram_window, {0, k_move_offset}); /// TODO: only check in first/last iteration without increasing code size if constexpr(kHasUnevenSplits) @@ -452,7 +487,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS s_acc, -numeric::infinity(), [&, - physical_seqlen_k_start_ = physical_seqlen_k_start, + physical_seqlen_k_start_ = is_sink_tile ? 0 : physical_seqlen_k_start, physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) { const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); if constexpr(kIsPagedKV) @@ -477,12 +512,26 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS number{}); if(need_perpixel_check) { - set_tile_if( - s_acc, -numeric::infinity(), [&](auto tile_idx) { - const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return mask.IsOutOfBound(row, col - kv_l2p_offset); - }); + auto apply_mask = [&](auto&& mask_func) { + set_tile_if( + s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = + q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = + k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask_func(row, col - kv_l2p_offset); + }); + }; + + if constexpr(kHasSink) + { + apply_mask( + [&](auto row, auto col) { return mask.IsOutOfSinkBound(row, col); }); + } + else + { + apply_mask([&](auto row, auto col) { return mask.IsOutOfBound(row, col); }); + } } } @@ -647,7 +696,12 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS } // move K tile windows i_page_block_k = k_page_block_navigator.move_tile_window( - i_page_block_k, k_dram_block_window, {kN0, 0}, physical_next_block_id_k); + i_page_block_k, k_dram_block_window, {k_move_offset, 0}, physical_next_block_id_k); + physical_next_block_id_v = + amd_wave_read_first_lane(v_page_block_navigator.prefetch_table_id( + i_page_block_v, v_dram_window, {0, k_move_offset - kN0})); + i_page_block_v = v_page_block_navigator.move_tile_window( + i_page_block_v, v_dram_window, {0, k_move_offset - kN0}, physical_next_block_id_v); // tail { block_sync_lds(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index cc0851efb3a..48ddb2e3fc8 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -62,6 +62,7 @@ struct BlockFmhaPipelineProblem static constexpr bool kHasDropout = Traits::kHasDropout; static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; + static constexpr bool kHasSink = Traits::kHasSink; }; template {}), number{}, number{}); - const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); + const auto tile_range_result = [&mask, &q_origin]() { + if constexpr(kHasSink) + return mask.GetSinkTileRangeAlongX( + q_origin.at(number<0>{}), number{}, number{}); + else + { + auto [start, end] = + mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); + return ck_tile::make_tuple(0, start, end); + } + }(); + const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{}); + const auto seqlen_k_start = tile_range_result.get(ck_tile::number<1>{}); + const auto seqlen_k_end = tile_range_result.get(ck_tile::number<2>{}); + + const auto kv_load_start = (sink_seq_end == 0 && seqlen_k_start > 0) ? seqlen_k_start : 0; + const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0); + const auto num_total_loop = + integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0) + num_sink_loop; // check early exit if no work to do if constexpr(FmhaMask::IsMasking || kPadSeqLenK) @@ -262,22 +279,22 @@ struct BlockFmhaPipelineQRKSVS auto k_dram_block_window = make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), k_dram_block_window_tmp.get_window_lengths(), - {seqlen_k_start, 0}); + {kv_load_start, 0}); const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); auto bias_dram_window = make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), bias_dram_block_window_tmp.get_window_lengths(), - {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N + {bias_origin.at(number<0>{}), kv_load_start}, // M/N Policy::template MakeBiasDramTileDistribution()); auto randval_dram_window = dropout.template MakeRandvalDramWindow( - randval_dram_block_window_tmp, seqlen_k_start); + randval_dram_block_window_tmp, kv_load_start); auto v_dram_window = make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), v_dram_block_window_tmp.get_window_lengths(), - {0, seqlen_k_start}, // TODO: hdim split? + {0, kv_load_start}, // TODO: hdim split? Policy::template MakeVDramTileDistribution()); auto q_tile = tile_elementwise_in(q_element_func, q); @@ -450,6 +467,11 @@ struct BlockFmhaPipelineQRKSVS #endif } } + if constexpr(kHasSink) + { + if(i_total_loops == 0) + move_tile_window(bias_dram_window, {0, seqlen_k_start - sink_seq_end}); + } move_tile_window(bias_dram_window, {0, kN0}); if constexpr(kPadSeqLenK || FmhaMask::IsMasking) { @@ -460,17 +482,34 @@ struct BlockFmhaPipelineQRKSVS number{}); if(need_perpixel_check) { - set_tile_if( - s_acc, -numeric::infinity(), [&](auto tile_idx) { - const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return !variant.LogitsMask(variant_params, - block_indices.batch_idx, - row, - col, - block_indices.qo_head_idx, - block_indices.kv_head_idx); + auto apply_mask = [&](auto&& mask_func) { + set_tile_if( + s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = + q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = + k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return !mask_func(variant_params, + block_indices.batch_idx, + row, + col, + block_indices.qo_head_idx, + block_indices.kv_head_idx); + }); + }; + + if constexpr(kHasSink) + { + apply_mask([&](auto&&... args) { + return variant.LogitsSinkMask(std::forward(args)...); + }); + } + else + { + apply_mask([&](auto&&... args) { + return variant.LogitsMask(std::forward(args)...); }); + } } } @@ -580,11 +619,23 @@ struct BlockFmhaPipelineQRKSVS if constexpr(kHasDropout) { - // K and dropout use the same address in LDS, finish loading from k_lds_window by - // gemm_0 to reuse LDS. block_sync_lds(); + auto randval_ptr = reinterpret_cast(smem_ptr); + + index_t seq_offset = [&]() { + if constexpr(!kHasSink) + return seqlen_k_start + i_total_loops * kN0; + + const bool in_sink_phase = (num_sink_loop > i_total_loops); + if(i_total_loops == num_sink_loop) + move_tile_window(randval_dram_window, {0, seqlen_k_start - sink_seq_end}); + + return in_sink_phase ? (kv_load_start + i_total_loops * kN0) + : (seqlen_k_start + (i_total_loops - num_sink_loop) * kN0); + }(); + dropout.template Run( - smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window); + randval_ptr, seq_offset, p_compute, randval_dram_window); } block_sync_lds(); @@ -636,6 +687,14 @@ struct BlockFmhaPipelineQRKSVS }); } // move K tile windows + if constexpr(kHasSink) + { + if(i_total_loops == 0) + { + move_tile_window(k_dram_block_window, {seqlen_k_start - sink_seq_end, 0}); + move_tile_window(v_dram_window, {0, seqlen_k_start - sink_seq_end}); + } + } move_tile_window(k_dram_block_window, {kN0, 0}); // tail { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index b67c28401f0..34e347ef2b7 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -62,6 +62,7 @@ struct BlockFmhaPipelineQRKSVSAsync static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kHasDropout = Problem::kHasDropout; + static constexpr bool kHasSink = Problem::kHasSink; static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || @@ -277,11 +278,26 @@ struct BlockFmhaPipelineQRKSVSAsync clear_tile(l); __builtin_amdgcn_sched_barrier(0); - const auto q_origin = q_dram_window.get_window_origin(); - const auto [seqlen_k_start, seqlen_k_end] = - mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); + const auto q_origin = q_dram_window.get_window_origin(); + const auto tile_range_result = [&mask, &q_origin]() { + if constexpr(kHasSink) + return mask.GetSinkTileRangeAlongX( + q_origin.at(number<0>{}), number{}, number{}); + else + { + auto [start, end] = + mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); + return ck_tile::make_tuple(0, start, end); + } + }(); + const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{}); + const auto seqlen_k_start = tile_range_result.get(ck_tile::number<1>{}); + const auto seqlen_k_end = tile_range_result.get(ck_tile::number<2>{}); - const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); + const auto kv_load_start = (sink_seq_end == 0 && seqlen_k_start > 0) ? seqlen_k_start : 0; + const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0); + const auto num_total_loop = + integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0) + num_sink_loop; // check early exit if no work to do if constexpr(FmhaMask::IsMasking || kPadSeqLenK) @@ -309,7 +325,7 @@ struct BlockFmhaPipelineQRKSVSAsync auto k_dram_block_window = make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), k_dram_block_window_tmp.get_window_lengths(), - {seqlen_k_start, 0}); + {kv_load_start, 0}); auto k_dram_window = make_tile_window( k_dram_block_window.get_bottom_tensor_view(), @@ -332,16 +348,16 @@ struct BlockFmhaPipelineQRKSVSAsync auto bias_dram_window = make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), bias_dram_block_window_tmp.get_window_lengths(), - {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N + {bias_origin.at(number<0>{}), kv_load_start}, // M/N Policy::template MakeBiasDramTileDistribution()); auto randval_dram_window = dropout.template MakeRandvalDramWindow( - randval_dram_block_window_tmp, seqlen_k_start); + randval_dram_block_window_tmp, kv_load_start); auto v_dram_window = make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), v_dram_block_window_tmp.get_window_lengths(), - {0, seqlen_k_start}, // TODO: hdim split? + {0, kv_load_start}, // TODO: hdim split? Policy::template MakeVDramTileDistribution()); // prefetch K tile @@ -478,6 +494,11 @@ struct BlockFmhaPipelineQRKSVSAsync #endif } } + if constexpr(kHasSink) + { + if(i_total_loops == 0) + move_tile_window(bias_dram_window, {0, seqlen_k_start - sink_seq_end}); + } move_tile_window(bias_dram_window, {0, kN0}); if constexpr(kPadSeqLenK || FmhaMask::IsMasking) { @@ -489,17 +510,34 @@ struct BlockFmhaPipelineQRKSVSAsync if(need_perpixel_check) { - set_tile_if( - s_acc, -numeric::infinity(), [&](auto tile_idx) { - const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return !variant.LogitsMask(variant_params, - block_indices.batch_idx, - row, - col, - block_indices.qo_head_idx, - block_indices.kv_head_idx); + auto apply_mask = [&](auto&& mask_func) { + set_tile_if( + s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = + q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = + k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return !mask_func(variant_params, + block_indices.batch_idx, + row, + col, + block_indices.qo_head_idx, + block_indices.kv_head_idx); + }); + }; + + if constexpr(kHasSink) + { + apply_mask([&](auto&&... args) { + return variant.LogitsSinkMask(std::forward(args)...); + }); + } + else + { + apply_mask([&](auto&&... args) { + return variant.LogitsMask(std::forward(args)...); }); + } } } @@ -647,11 +685,21 @@ struct BlockFmhaPipelineQRKSVSAsync { auto randval_ptr = reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); + + index_t seq_offset = [&]() { + if constexpr(!kHasSink) + return seqlen_k_start + i_total_loops * kN0; + + const bool in_sink_phase = (num_sink_loop > i_total_loops); + if(i_total_loops == num_sink_loop) + move_tile_window(randval_dram_window, {0, seqlen_k_start - sink_seq_end}); + + return in_sink_phase ? (kv_load_start + i_total_loops * kN0) + : (seqlen_k_start + (i_total_loops - num_sink_loop) * kN0); + }(); + dropout.template Run( - randval_ptr, - seqlen_k_start + i_total_loops * kN0, - p_compute, - randval_dram_window); + randval_ptr, seq_offset, p_compute, randval_dram_window); } const auto p = [&]() { @@ -717,8 +765,16 @@ struct BlockFmhaPipelineQRKSVSAsync i_total_loops++; if(i_total_loops < num_total_loop) { - // move K tile windows + if constexpr(kHasSink) + { + if(i_total_loops == 0) + { + move_tile_window(k_dram_block_window, {seqlen_k_start - sink_seq_end, 0}); + move_tile_window(v_dram_window, {0, seqlen_k_start - sink_seq_end}); + } + } move_tile_window(k_dram_block_window, {kN0, 0}); + k_dram_window.set_window_origin(k_dram_block_window.get_window_origin()); if constexpr(k1_loops >= 2 && diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp index 08fc42a4716..e837ca49f2a 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp @@ -69,6 +69,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr bool kHasUnevenSplits = true; + static constexpr bool kHasSink = Problem::kHasSink; static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index 59267fa3b17..77f216c72dd 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -19,8 +19,9 @@ template + index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */ + bool kSkipMinSeqlenQ_ = false, /* skip min seqlen q while chunked prefill */ + bool kHasSink_ = false> struct TileFmhaTraits { static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; @@ -35,6 +36,7 @@ struct TileFmhaTraits static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; static constexpr index_t kBlockPerCu = kBlockPerCu_; static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; + static constexpr bool kHasSink = kHasSink_; }; template 1 or fwd training is running */ bool kIsPagedKV_, bool kDoFp8StaticQuant_, - index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */ - bool kSkipMinSeqlenQ_ = false /* skip min seqlen q while chunked prefill */> + index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */ + bool kSkipMinSeqlenQ_ = false, /* skip min seqlen q while chunked prefill */ + bool kHasSink_ = false> struct TileFmhaFwdPagedKVTraits { static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; @@ -80,6 +83,7 @@ struct TileFmhaFwdPagedKVTraits static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; static constexpr index_t kBlockPerCu = kBlockPerCu_; static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; + static constexpr bool kHasSink = kHasSink_; }; template + index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */ + bool kHasSink_ = false> struct TileFmhaFwdSplitKVTraits { static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; @@ -111,6 +116,7 @@ struct TileFmhaFwdSplitKVTraits static constexpr bool kHasUnevenSplits = kHasUnevenSplits_; static constexpr bool kMergeNumHeadGroupsSeqLenQ = kMergeNumHeadGroupsSeqLenQ_; static constexpr index_t kBlockPerCu = kBlockPerCu_; + static constexpr bool kHasSink = kHasSink_; }; template