Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion example/ck_tile/01_fmha/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ set(FMHA_BWD_CODE_GEN_COMMON_ARGS
# there is no corresponding instance for parameters).
if(BUILD_TESTING)
# Filters are in the order of FMHA_FWD_KNOWN_APIS: fwd,fwd_splitkv_combine@fwd_splitkv,fwd_appendkv,pagedkv_prefill
list(APPEND FMHA_FWD_CODE_GEN_COMMON_ARGS --filter *_nlogits*_nskip*_nsink*,*@*_nlogits*_nbias*_nsink*,*,*_nlogits*_nskip*_pagedkv*)
list(APPEND FMHA_FWD_CODE_GEN_COMMON_ARGS --filter *_nlogits*_nskip*,*@*_nlogits*_nbias*,*,*_nlogits*_nskip*_pagedkv)
endif()

# generate a list of kernels, but not actually emit files at config sta
Expand Down
74 changes: 28 additions & 46 deletions example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py

Large diffs are not rendered by default.

42 changes: 13 additions & 29 deletions example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,7 @@
{F_pagedkv},
kHasUnevenSplits,
kMergeNumHeadGroupsSeqLenQ,
{F_occupancy},
{F_sink}>;
{F_occupancy}>;

using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
Expand Down Expand Up @@ -119,7 +118,7 @@
}} // anonymous namespace

using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_sink}, {F_spad}, {F_skpad}, {F_dpad},
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad},
{F_dvpad}>;

#pragma clang diagnostic push
Expand Down Expand Up @@ -281,8 +280,8 @@
"""

FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
((a.block_table_ptr != nullptr) == {F_pagedkv}) && (t.has_sink == {F_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv},{F_sink}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;

// get combine kernel tile sizes
using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType;
Expand Down Expand Up @@ -334,15 +333,14 @@ class FmhaFwdSplitKVApiTrait:
dpad: str
dvpad: str
pagedkv: str
sink: str # sink or not
bn1comb: int # tile size along v head_dim of combine kernel

@property
def name(self) -> str:
return (
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-"
+ f"{self.dvpad}-{self.pagedkv}-{self.sink}"
+ f"{self.dvpad}-{self.pagedkv}"
)

@property
Expand Down Expand Up @@ -428,7 +426,6 @@ class FmhaFwdSplitKVPipeline:
F_lse: str #
F_squant: str #
F_pagedkv: str # t/f
F_sink: str # t/f
F_mask: str # value from MASK_MAP

@property
Expand Down Expand Up @@ -489,10 +486,6 @@ def pad_name() -> str:
n += "_pagedkv"
else:
n += "_npagedkv"
if self.F_sink == "t":
n += "_sink"
else:
n += "_nsink"
return n


Expand Down Expand Up @@ -575,7 +568,6 @@ def api(self) -> str:
F_lse=BOOL_MAP[trait.lse],
F_squant=BOOL_MAP[trait.squant],
F_pagedkv=BOOL_MAP[trait.pagedkv],
F_sink=BOOL_MAP[trait.sink],
F_scheck=trait.scheck,
F_skcheck=trait.skcheck,
F_dcheck=trait.dcheck,
Expand Down Expand Up @@ -676,7 +668,6 @@ def template(self) -> str:
F_squant=BOOL_MAP[self.F_pipeline.F_squant],
F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv],
F_occupancy=self.F_tile.F_occupancy,
F_sink=BOOL_MAP[self.F_pipeline.F_sink],
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
F_mode=MODE_MAP[self.F_mode],
Expand Down Expand Up @@ -750,23 +741,19 @@ def get_pipelines(dtype, hdim, mask_impl) -> List[FmhaFwdSplitKVPipeline]:
squant = "t" if dtype == "fp8" else "f"
pipelines = []
if dtype in ["fp16", "bf16"]:
for logits, mask, bias, pagedkv, sink in itertools.product(
["t", "f"],
get_mask_map(mask_impl).keys(),
BIAS_MAP.keys(),
["t", "f"],
["t", "f"],
for logits, mask, bias, pagedkv in itertools.product(
["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]
):
pipelines.append(Pipeline("qr", "row", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, sink, mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "f", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "t", "f", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "t", "t", "t", "t", logits, bias, "t", squant, pagedkv, mask)) # fmt: skip
elif dtype in ["fp8", "bf8"]:
for logits, mask, bias in itertools.product(
["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
):
pipelines.append(Pipeline("qr", "row", "f", "f", "f", "f", logits, bias, "t", squant, "f", "f", mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, "f", "f", mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "f", "f", "f", "f", logits, bias, "t", squant, "f", mask)) # fmt: skip
pipelines.append(Pipeline("qr", "row", "t", "t", "f", "f", logits, bias, "t", squant, "f", mask)) # fmt: skip
elif dtype in ["fp8fp16", "fp8bf16"]:
# TODO
None
Expand Down Expand Up @@ -922,7 +909,6 @@ def get_fwd_splitkv_blobs(
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_bias in ["no", "alibi"]
cond &= pipeline.F_squant == "f"
cond &= pipeline.F_sink == "f"
if not cond:
continue
# PyTorch integration
Expand All @@ -932,7 +918,6 @@ def get_fwd_splitkv_blobs(
cond &= pipeline.F_bias in ["no", "bias"]
cond &= pipeline.F_squant == "f"
cond &= mode == "batch"
cond &= pipeline.F_sink == "f"
if not cond:
continue
# Aiter(mha_varlen_fwd) integration
Expand Down Expand Up @@ -1091,7 +1076,6 @@ def write_blobs(
lse=kernel.F_pipeline.F_lse,
squant=kernel.F_pipeline.F_squant,
pagedkv=kernel.F_pipeline.F_pagedkv,
sink=kernel.F_pipeline.F_sink,
spad=kernel.F_pipeline.F_spad,
skpad=kernel.F_pipeline.F_skpad,
dpad=kernel.F_pipeline.F_dpad,
Expand Down
33 changes: 10 additions & 23 deletions example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,7 @@
{F_pagedkv}, //pagedkv
{F_squant},
{F_occupancy},
{F_skip},
{F_sink}>;
{F_skip}>;

using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;

Expand Down Expand Up @@ -102,7 +101,7 @@
ck_tile::FmhaFwdPagedKVKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;

using trait_{F_idx} = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}, {F_sink}>;
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;

template<>
float fmha_fwd_pagedkv_<trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_fwd_pagedkv_args a)
Expand Down Expand Up @@ -131,9 +130,9 @@
}}
"""

FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && (t.has_sink == {F_sink}) &&
FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip},{F_sink}>;
using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
return fmha_fwd_pagedkv_<trait_, {F_arch.tag}>(s, a);
}}
"""
Expand Down Expand Up @@ -165,13 +164,12 @@ class FmhaFwdApiTrait:
dpad: str
dvpad: str
skip: str
sink: str

@property
def name(self) -> str:
return (
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.pagedkv}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}-{self.sink}"
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.pagedkv}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}"
)

@property
Expand Down Expand Up @@ -259,7 +257,6 @@ class FmhaFwdPipeline:
F_squant: str #
F_mask: str # value from MASK_MAP
F_skip: str # true/false
F_sink: str # true/false

@property
def name(self) -> str:
Expand Down Expand Up @@ -324,10 +321,6 @@ def pad_name() -> str:
n += "_pagedkv"
else:
n += "_npagedkv"
if self.F_sink == "t":
n += "_sink"
else:
n += "_nsink"

return n

Expand Down Expand Up @@ -371,7 +364,6 @@ def api(self) -> str:
F_lse=BOOL_MAP[trait.lse],
F_pagedkv=BOOL_MAP[trait.pagedkv],
F_skip=BOOL_MAP[trait.skip],
F_sink=BOOL_MAP[trait.sink],
F_squant=BOOL_MAP[trait.squant],
F_scheck=trait.scheck,
F_skcheck=trait.skcheck,
Expand Down Expand Up @@ -489,7 +481,6 @@ def template(self) -> str:
F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv],
F_squant=BOOL_MAP[self.F_pipeline.F_squant],
F_skip=BOOL_MAP[self.F_pipeline.F_skip],
F_sink=BOOL_MAP[self.F_pipeline.F_sink],
F_occupancy=self.F_tile.F_occupancy,
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
Expand Down Expand Up @@ -536,7 +527,6 @@ def api_trait(self) -> FmhaFwdApiTrait:
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad,
skip=self.F_pipeline.F_skip,
sink=self.F_pipeline.F_sink,
)


Expand All @@ -550,23 +540,22 @@ def get_pipelines(dtype, hdim, mask_impl) -> List[FmhaFwdPipeline]:
squant = "t" if dtype == "fp8" else "f"
pipelines = []
if dtype in ["fp16", "bf16"]:
for logits, mask, bias, pagedkv, skip, sink in itertools.product(
for logits, mask, bias, pagedkv, skip in itertools.product(
["t", "f"],
get_mask_map(mask_impl).keys(),
BIAS_MAP.keys(),
["t"],
["f"],
["t", "f"],
):
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "f", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip, sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip, sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "f", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", pagedkv, squant, mask, skip)) # fmt: skip
elif dtype in ["fp8", "bf8"]:
# no need lse/dropout kernels
for logits, mask, bias in itertools.product(
["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()
):
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "f", "f", "f", "f", logits, bias, "f", "t", squant, mask, "f", "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", "t", squant, mask, "f", "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "f", "f", "f", "f", logits, bias, "f", "t", squant, mask, "f")) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_pagedkv", "row", "t", "t", "f", "f", logits, bias, "f", "t", squant, mask, "f")) # fmt: skip
elif dtype in ["fp8fp16", "fp8bf16"]:
pass # TODO
else:
Expand Down Expand Up @@ -690,7 +679,6 @@ def get_fwd_blobs(
cond &= pipeline.F_bias in ["no", "alibi"]
cond &= pipeline.F_squant == "f"
cond &= pipeline.F_skip == "f"
cond &= pipeline.F_sink == "f"
if not cond:
continue
# PyTorch integration
Expand All @@ -700,7 +688,6 @@ def get_fwd_blobs(
cond &= pipeline.F_bias in ["no", "bias"]
cond &= pipeline.F_squant == "f"
cond &= pipeline.F_skip == "f"
cond &= pipeline.F_sink == "f"
if not cond:
continue
# Aiter(mha_fwd) integration
Expand Down
Loading