Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
a092cf2
enable attn sink
LJ-underdog Sep 22, 2025
4e6f2de
update attn_sink script
LJ-underdog Sep 22, 2025
e89fd09
fix some error
LJ-underdog Sep 22, 2025
14bef45
Merge branch 'develop' into lj/attn_sink
LJ-underdog Sep 22, 2025
628082d
clang-format
LJ-underdog Sep 23, 2025
80be924
Merge branch 'develop' into lj/attn_sink
LJ-underdog Sep 23, 2025
77a1107
update fmha_bwd mask
LJ-underdog Sep 24, 2025
33db67c
Merge branch 'develop' into lj/attn_sink
LJ-underdog Sep 24, 2025
27b79f1
update fmha_bwd_kernel's mask
LJ-underdog Sep 24, 2025
8c659bc
update block_fmha_pipeline_qr_ks_vs.hpp
LJ-underdog Sep 24, 2025
ba22123
Merge branch 'develop' into lj/attn_sink
LJ-underdog Sep 25, 2025
f83007d
Merge branch 'develop' into lj/attn_sink
LJ-underdog Sep 26, 2025
7b51f1b
fix ci error
LJ-underdog Sep 28, 2025
9438f5f
Merge branch 'develop' into lj/attn_sink
LJ-underdog Sep 28, 2025
a946d48
fix format error
LJ-underdog Sep 28, 2025
4f06d30
Update block_fmha_bwd_pipeline_default_policy.hpp
LJ-underdog Sep 28, 2025
42af7c2
Update fmha_fwd_runner.hpp
LJ-underdog Sep 29, 2025
00e8ae8
Update block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp
LJ-underdog Sep 29, 2025
aa9b1b1
Update fmha_fwd_runner.hpp
LJ-underdog Sep 29, 2025
171d16c
Update fmha_fwd_runner.hpp
LJ-underdog Sep 29, 2025
3290f92
Update fmha_fwd_runner.hpp
LJ-underdog Sep 29, 2025
edaef37
update splitkv_pipeline
LJ-underdog Sep 30, 2025
3ea15ba
update splitkv&pagedkv pipeline
LJ-underdog Oct 23, 2025
d4d9698
add sink test
LJ-underdog Oct 23, 2025
e116fa0
update attn_sink result log
LJ-underdog Oct 24, 2025
fd7f676
Merge branch 'develop' into lj/attn_sink
LJ-underdog Oct 24, 2025
dd6f264
update smoke_test_fwd_sink.sh
LJ-underdog Oct 27, 2025
648cbf3
update test file
LJ-underdog Oct 28, 2025
0b634ab
update test script
LJ-underdog Oct 28, 2025
11e35ad
Merge branch 'develop' into lj/attn_sink
LJ-underdog Nov 3, 2025
912bc7e
Update block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
LJ-underdog Nov 3, 2025
d06476f
Merge branch 'develop' into lj/attn_sink
LJ-underdog Nov 3, 2025
794670a
use constexpr kHasSink for sink in fmha pipeline
LJ-underdog Nov 10, 2025
48bbd67
update by pre-commit
LJ-underdog Nov 10, 2025
569dab0
Merge branch 'develop' into lj/attn_sink
LJ-underdog Nov 10, 2025
c657c1f
Update include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs…
LJ-underdog Nov 10, 2025
9463138
Update include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs…
LJ-underdog Nov 10, 2025
cdbbf55
Update include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp
LJ-underdog Nov 10, 2025
2007377
Update fmha_fwd.py
LJ-underdog Nov 10, 2025
6779577
Update example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
LJ-underdog Nov 10, 2025
fa3df18
Update include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipel…
LJ-underdog Nov 10, 2025
f772d03
Merge branch 'develop' into lj/attn_sink
LJ-underdog Nov 13, 2025
3cffa0d
Remove causal mask setting logic from mask.hpp
LJ-underdog Nov 14, 2025
84ca45f
Merge branch 'develop' into lj/attn_sink
LJ-underdog Nov 14, 2025
f08d861
fix CI error: some lambda usage is not supported in C++17
LJ-underdog Nov 14, 2025
afc92b7
Merge branch 'develop' into lj/attn_sink
LJ-underdog Nov 14, 2025
0c7cb3d
Update remod.py
LJ-underdog Nov 17, 2025
5ce1524
add smoke sink test
LJ-underdog Nov 17, 2025
14c05cb
Merge branch 'develop' into lj/attn_sink
LJ-underdog Nov 17, 2025
2fe5243
Update fmha_pagedkv_prefill.py
LJ-underdog Nov 17, 2025
4f5116f
Update FmhaFwdPipeline parameters in fmha_fwd.py
LJ-underdog Nov 17, 2025
3259192
update block_fmha_pipeline_qr_ks_vs_async_trload.hpp
LJ-underdog Nov 17, 2025
79d0e45
fix C++17 unsupported-feature error
LJ-underdog Nov 18, 2025
bd06e29
Update block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp
LJ-underdog Nov 18, 2025
0debfd9
Fix formatting of sink_seq_end assignment
LJ-underdog Nov 18, 2025
6aee384
Fix indentation for sink_seq_end assignment
LJ-underdog Nov 18, 2025
e2b2254
Update block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp
LJ-underdog Nov 18, 2025
38fc5d6
Merge branch 'develop' into lj/attn_sink
LJ-underdog Nov 18, 2025
7671121
Merge branch 'develop' into lj/attn_sink
LJ-underdog Nov 19, 2025
e414747
Merge branch 'develop' into lj/attn_sink
LJ-underdog Nov 19, 2025
c24ccab
Merge branch 'develop' into lj/attn_sink
LJ-underdog Nov 20, 2025
660efeb
Merge branch 'develop' into lj/attn_sink
LJ-underdog Nov 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion example/ck_tile/01_fmha/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ set(FMHA_BWD_CODE_GEN_COMMON_ARGS
# there is no corresponding instance for parameters).
if(BUILD_TESTING)
# Filters are in the order of FMHA_FWD_KNOWN_APIS: fwd,fwd_splitkv_combine@fwd_splitkv,fwd_appendkv,pagedkv_prefill
list(APPEND FMHA_FWD_CODE_GEN_COMMON_ARGS --filter *_nlogits*_nskip*,*@*_nlogits*_nbias*,*,*_nlogits*_nskip*_pagedkv)
list(APPEND FMHA_FWD_CODE_GEN_COMMON_ARGS --filter *_nlogits*_nskip*_nsink*,*@*_nlogits*_nbias*_nsink*,*,*_nlogits*_nskip*_pagedkv*)
endif()

# generate a list of kernels, but not actually emit files at config sta
Expand Down
74 changes: 46 additions & 28 deletions example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py

Large diffs are not rendered by default.

42 changes: 29 additions & 13 deletions example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@
{F_pagedkv},
kHasUnevenSplits,
kMergeNumHeadGroupsSeqLenQ,
{F_occupancy}>;
{F_occupancy},
{F_sink}>;

using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
Expand Down Expand Up @@ -118,7 +119,7 @@
}} // anonymous namespace

using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad},
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_sink}, {F_spad}, {F_skpad}, {F_dpad},
{F_dvpad}>;

#pragma clang diagnostic push
Expand Down Expand Up @@ -280,8 +281,8 @@
"""

FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
((a.block_table_ptr != nullptr) == {F_pagedkv}) && (t.has_sink == {F_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv},{F_sink}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;

// get combine kernel tile sizes
using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType;
Expand Down Expand Up @@ -333,14 +334,15 @@ class FmhaFwdSplitKVApiTrait:
dpad: str
dvpad: str
pagedkv: str
sink: str # sink or not
bn1comb: int # tile size along v head_dim of combine kernel

@property
def name(self) -> str:
return (
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-"
+ f"{self.dvpad}-{self.pagedkv}"
+ f"{self.dvpad}-{self.pagedkv}-{self.sink}"
)

@property
Expand Down Expand Up @@ -426,6 +428,7 @@ class FmhaFwdSplitKVPipeline:
F_lse: str #
F_squant: str #
F_pagedkv: str # t/f
F_sink: str # t/f
F_mask: str # value from MASK_MAP

@property
Expand Down Expand Up @@ -486,6 +489,10 @@ def pad_name() -> str:
n += "_pagedkv"
else:
n += "_npagedkv"
if self.F_sink == "t":
n += "_sink"
else:
n += "_nsink"
return n


Expand Down Expand Up @@ -568,6 +575,7 @@ def api(self) -> str:
F_lse=BOOL_MAP[trait.lse],
F_squant=BOOL_MAP[trait.squant],
F_pagedkv=BOOL_MAP[trait.pagedkv],
F_sink=BOOL_MAP[trait.sink],
F_scheck=trait.scheck,
F_skcheck=trait.skcheck,
F_dcheck=trait.dcheck,
Expand Down Expand Up @@ -668,6 +676,7 @@ def template(self) -> str:
F_squant=BOOL_MAP[self.F_pipeline.F_squant],
F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv],
F_occupancy=self.F_tile.F_occupancy,
F_sink=BOOL_MAP[self.F_pipeline.F_sink],
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
F_mode=MODE_MAP[self.F_mode],
Expand Down Expand Up @@ -741,19 +750,23 @@ def get_pipelines(dtype, hdim, mask_impl) -> List[FmhaFwdSplitKVPipeline]:
squant = "t" if dtype == "fp8" else "f"
pipelines = []
if dtype in ["fp16", "bf16"]:
for logits, mask, bias, pagedkv in itertools.product(
["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]
for logits, mask, bias, pagedkv, sink in itertools.product(
["t", "f"],
get_mask_map(mask_impl).keys(),
BIAS_MAP.keys(),
["t", "f"],
["t", "f"],
):
pipelines.append(Pipeline("qr", "row", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
elif dtype in ["fp8", "bf8"]:
for logits, mask, bias in itertools.product(
["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
):
pipelines.append(Pipeline("qr", "row", "f", "f", "f", "f", logits, bias, "t", squant, "f", mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, "f", mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "f", "f", "f", "f", logits, bias, "t", squant, "f", "f", mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, "f", "f", mask)) # fmt: skip
elif dtype in ["fp8fp16", "fp8bf16"]:
# TODO
None
Expand Down Expand Up @@ -909,6 +922,7 @@ def get_fwd_splitkv_blobs(
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_bias in ["no", "alibi"]
cond &= pipeline.F_squant == "f"
cond &= pipeline.F_sink == "f"
if not cond:
continue
# PyTorch integration
Expand All @@ -918,6 +932,7 @@ def get_fwd_splitkv_blobs(
cond &= pipeline.F_bias in ["no", "bias"]
cond &= pipeline.F_squant == "f"
cond &= mode == "batch"
cond &= pipeline.F_sink == "f"
if not cond:
continue
# Aiter(mha_varlen_fwd) integration
Expand Down Expand Up @@ -1076,6 +1091,7 @@ def write_blobs(
lse=kernel.F_pipeline.F_lse,
squant=kernel.F_pipeline.F_squant,
pagedkv=kernel.F_pipeline.F_pagedkv,
sink=kernel.F_pipeline.F_sink,
spad=kernel.F_pipeline.F_spad,
skpad=kernel.F_pipeline.F_skpad,
dpad=kernel.F_pipeline.F_dpad,
Expand Down
33 changes: 23 additions & 10 deletions example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@
{F_pagedkv}, //pagedkv
{F_squant},
{F_occupancy},
{F_skip}>;
{F_skip},
{F_sink}>;

using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;

Expand Down Expand Up @@ -101,7 +102,7 @@
ck_tile::FmhaFwdPagedKVKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;

using trait_{F_idx} = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}, {F_sink}>;

template<>
float fmha_fwd_pagedkv_<trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_fwd_pagedkv_args a)
Expand Down Expand Up @@ -130,9 +131,9 @@
}}
"""

FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&
FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && (t.has_sink == {F_sink}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip},{F_sink}>;
return fmha_fwd_pagedkv_<trait_, {F_arch.tag}>(s, a);
}}
"""
Expand Down Expand Up @@ -164,12 +165,13 @@ class FmhaFwdApiTrait:
dpad: str
dvpad: str
skip: str
sink: str

@property
def name(self) -> str:
return (
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.pagedkv}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}"
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.pagedkv}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}-{self.sink}"
)

@property
Expand Down Expand Up @@ -257,6 +259,7 @@ class FmhaFwdPipeline:
F_squant: str #
F_mask: str # value from MASK_MAP
F_skip: str # true/false
F_sink: str # true/false

@property
def name(self) -> str:
Expand Down Expand Up @@ -321,6 +324,10 @@ def pad_name() -> str:
n += "_pagedkv"
else:
n += "_npagedkv"
if self.F_sink == "t":
n += "_sink"
else:
n += "_nsink"

return n

Expand Down Expand Up @@ -364,6 +371,7 @@ def api(self) -> str:
F_lse=BOOL_MAP[trait.lse],
F_pagedkv=BOOL_MAP[trait.pagedkv],
F_skip=BOOL_MAP[trait.skip],
F_sink=BOOL_MAP[trait.sink],
F_squant=BOOL_MAP[trait.squant],
F_scheck=trait.scheck,
F_skcheck=trait.skcheck,
Expand Down Expand Up @@ -481,6 +489,7 @@ def template(self) -> str:
F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv],
F_squant=BOOL_MAP[self.F_pipeline.F_squant],
F_skip=BOOL_MAP[self.F_pipeline.F_skip],
F_sink=BOOL_MAP[self.F_pipeline.F_sink],
F_occupancy=self.F_tile.F_occupancy,
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
Expand Down Expand Up @@ -527,6 +536,7 @@ def api_trait(self) -> FmhaFwdApiTrait:
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad,
skip=self.F_pipeline.F_skip,
sink=self.F_pipeline.F_sink,
)


Expand All @@ -540,22 +550,23 @@ def get_pipelines(dtype, hdim, mask_impl) -> List[FmhaFwdPipeline]:
squant = "t" if dtype == "fp8" else "f"
pipelines = []
if dtype in ["fp16", "bf16"]:
for logits, mask, bias, pagedkv, skip in itertools.product(
for logits, mask, bias, pagedkv, skip, sink in itertools.product(
["t", "f"],
get_mask_map(mask_impl).keys(),
BIAS_MAP.keys(),
["t"],
["f"],
["t", "f"],
):
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "f", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "f", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip, sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip, sink)) # fmt: skip
elif dtype in ["fp8", "bf8"]:
# no need lse/dropout kernels
for logits, mask, bias in itertools.product(
["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
):
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "f", "f", "f", "f", logits, bias, "f", "t", squant, mask, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", "t", squant, mask, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "f", "f", "f", "f", logits, bias, "f", "t", squant, mask, "f", "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", "t", squant, mask, "f", "f")) # fmt: skip
elif dtype in ["fp8fp16", "fp8bf16"]:
pass # TODO
else:
Expand Down Expand Up @@ -679,6 +690,7 @@ def get_fwd_blobs(
cond &= pipeline.F_bias in ["no", "alibi"]
cond &= pipeline.F_squant == "f"
cond &= pipeline.F_skip == "f"
cond &= pipeline.F_sink == "f"
if not cond:
continue
# PyTorch integration
Expand All @@ -688,6 +700,7 @@ def get_fwd_blobs(
cond &= pipeline.F_bias in ["no", "bias"]
cond &= pipeline.F_squant == "f"
cond &= pipeline.F_skip == "f"
cond &= pipeline.F_sink == "f"
if not cond:
continue
# Aiter(mha_fwd) integration
Expand Down
Loading