Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1095,7 +1095,10 @@ def get_pipelines(


class KernelComponentFactoryGfx11(CompatibilityRuleFactory):
arch = ArchTrait("gfx11")
arch = ArchTrait(
"gfx11",
preprocessor_check="defined(__gfx11__) && !defined(__gfx115__)",
)

_DT_FP16_BF16 = ("fp16", "bf16")

Expand All @@ -1109,10 +1112,12 @@ def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
return {
# bm0, bn0, bk0, bn1, bk1,
( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
(128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 4096")),
FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
(128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 4096")),
FmhaFwdTileSize(128, 64, 32, 128, 32, 128, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, 6)],
(192, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
(256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
(256, 256) : [FmhaFwdTileSize(128, 64, 32, 256, 32, 256, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, 6)]
} # fmt: skip
else:
raise ValueError(f"unsupported dtype={dtype}")
Expand All @@ -1133,12 +1138,25 @@ def get_pipelines(
["t", "f"],
["t", "f"],
):
pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
# Keep only ttff/tttt for gfx11: ffff path is often similar or worse
# pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
Comment thread
ex-rzr marked this conversation as resolved.
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
return pipelines


class KernelComponentFactoryGfx115(KernelComponentFactoryGfx11):
arch = ArchTrait("gfx115")

@classmethod
def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
result = super().get_hdim_tile_size_dict(dtype)
if dtype in cls._DT_FP16_BF16:
result[(64, 64)] = [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)] # fmt: skip
result[(256, 256)] = [FmhaFwdTileSize(128, 64, 32, 256, 32, 256, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, -1)] # fmt: skip
return result


class KernelComponentFactoryGfx12(CompatibilityRuleFactory):
arch = ArchTrait("gfx12")

Expand Down Expand Up @@ -1230,6 +1248,8 @@ def get_factory(target: str):
if target.startswith("gfx9"):
return KernelComponentFactoryGfx9

if target.startswith("gfx115"):
return KernelComponentFactoryGfx115
if target.startswith("gfx11"):
return KernelComponentFactoryGfx11
if target.startswith("gfx12"):
Expand Down
1 change: 1 addition & 0 deletions projects/composablekernel/include/ck_tile/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "ck_tile/core/arch/mma/mma_selector.hpp"
#include "ck_tile/core/arch/mma/mma_traits.hpp"
#include "ck_tile/core/arch/mma/mma_transforms.hpp"
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_register_mapper.hpp"
#include "ck_tile/core/arch/mma/wmma/wmma.hpp"
#include "ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp"
#include "ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1141,6 +1141,9 @@ struct gfx103_t
struct gfx11_t
{
};
struct gfx115_t
{
};
struct gfx12_t
{
};
Expand Down Expand Up @@ -1174,6 +1177,8 @@ CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx103_t) { return 32; }

CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx11_t) { return 32; }

CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx115_t) { return 32; }

CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx12_t) { return 32; }

CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx950_t) { return 64; }
Expand Down
3 changes: 3 additions & 0 deletions projects/composablekernel/include/ck_tile/core/config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
defined(__gfx1152__) || defined(__gfx1153__) || defined(__gfx11_generic__)
#define __gfx11__
#endif
#if defined(__gfx1150__) || defined(__gfx1151__) || defined(__gfx1152__) || defined(__gfx1153__)
#define __gfx115__
#endif
#if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__)
#define __gfx12__
#endif
Expand Down
Loading