Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,47 +38,56 @@ class TileKernelInstance:

BlockPerCu: int # 1..BLOCK_PER_CU_MAX

# When True, 8-warp kernels read x_scale in row-major layout natively,
# skipping the host-side transpose.
AQRowMajor: bool = False

@property
def is_eight_warp(self) -> bool:
    """True when the warp grid holds exactly 8 warps and K_Warp_Tile is 128."""
    total_warps = self.M_Warp * self.N_Warp * self.K_Warp
    return (total_warps, self.K_Warp_Tile) == (8, 128)

@property
def name(self) -> str:
    """
    Generate a unique name for the kernel instance based on its parameters.

    Layout (groups joined by "_", values within a group joined by "x"):
    a8w8_blockscale_cktile_<M_Tile x N_Tile x K_Tile>_<M_Warp x N_Warp x K_Warp>
    _<M_Warp_Tile x N_Warp_Tile x K_Warp_Tile>_<scheduler>_<flags>_<BlockPerCu>
    with a trailing "_aqrm" marker appended only when AQRowMajor is True, so a
    row-major-AQ instance never collides with its column-major counterpart.
    """
    parts = [
        "a8w8_blockscale_cktile",
        "x".join(str(v) for v in (self.M_Tile, self.N_Tile, self.K_Tile)),
        "x".join(str(v) for v in (self.M_Warp, self.N_Warp, self.K_Warp)),
        "x".join(
            str(v) for v in (self.M_Warp_Tile, self.N_Warp_Tile, self.K_Warp_Tile)
        ),
        self.Scheduler.lower(),
        # Booleans are encoded as 0/1 to keep the name compact and stable.
        "x".join(
            str(int(flag))
            for flag in (self.TiledMMAPermuteN, self.TransposeC, self.UsePersistentKernel)
        ),
        str(self.BlockPerCu),
    ]
    if self.AQRowMajor:
        parts.append("aqrm")
    return "_".join(parts)


# Upper bound for the BlockPerCu sweep: `expand_blockpercu` uses this as its
# `max_bpc` default, and TileKernelInstance.BlockPerCu is documented as 1..BLOCK_PER_CU_MAX.
BLOCK_PER_CU_MAX = 4
Expand Down Expand Up @@ -131,6 +140,8 @@ def expand_blockpercu(base_dict, max_bpc=BLOCK_PER_CU_MAX, field_name="BlockPerC
9: TileKernelInstance( 128, 128, 128, 1, 4, 1, 16, 16, 128, "Intrawave", False, True, False, 1 ),
10: TileKernelInstance( 128, 128, 128, 2, 2, 1, 16, 16, 128, "Intrawave", False, True, False, 2 ),
11: TileKernelInstance( 192, 256, 128, 4, 2, 1, 16, 16, 128, "Intrawave", False, True, False, 1 ),
# 8-warp kernel (4x2x1=8) with AQRowMajor=True: skip host-side x_scale transpose
12: TileKernelInstance( 192, 256, 128, 4, 2, 1, 16, 16, 128, "Intrawave", False, True, False, 1, AQRowMajor=True),
}

default_kernels_cktile_dict = {
Expand Down
3 changes: 2 additions & 1 deletion csrc/ck_gemm_a8w8_blockscale/gen_instances_cktile.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ def gen_cktile_instance(self, k: TileKernelInstance):
{str(k.TransposeC).lower()},
{str(k.UsePersistentKernel).lower()},
ck_tile::GemmPipelineScheduler::{k.Scheduler},
{k.BlockPerCu}>;
{k.BlockPerCu},
{str(k.AQRowMajor).lower()}>;

// Run kernel instance.
return gemm_a8w8_blockscale_cktile_impl<DDataType, EDataType, TileGemmInstance>(XQ, WQ, x_scale, w_scale, Y, preshuffleB);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ template <ck_tile::index_t M_Tile,
bool TransposeC = false,
bool UsePersistentKernel = false,
ck_tile::GemmPipelineScheduler Scheduler = ck_tile::GemmPipelineScheduler::Intrawave,
int BlockPerCu = 1>
int BlockPerCu = 1,
bool AQRowMajor = false>
struct CreateTileGemmConfig
{
static constexpr ck_tile::index_t M_Tile_v = M_Tile;
Expand All @@ -77,6 +78,7 @@ struct CreateTileGemmConfig
static constexpr bool UsePersistentKernel_v = UsePersistentKernel;
static constexpr ck_tile::GemmPipelineScheduler Scheduler_v = Scheduler;
static constexpr int BlockPerCu_v = BlockPerCu;
static constexpr bool AQRowMajor_v = AQRowMajor;
};

template <ck_tile::index_t M_Tile,
Expand All @@ -92,7 +94,8 @@ template <ck_tile::index_t M_Tile,
bool TransposeC = false,
bool UsePersistentKernel = false,
ck_tile::GemmPipelineScheduler Scheduler = ck_tile::GemmPipelineScheduler::Intrawave,
int BlockPerCu = 1>
int BlockPerCu = 1,
bool AQRowMajor = false>
using TileGemmConfig = CreateTileGemmConfig<M_Tile,
N_Tile,
K_Tile,
Expand All @@ -106,7 +109,8 @@ using TileGemmConfig = CreateTileGemmConfig<M_Tile,
TransposeC,
UsePersistentKernel,
Scheduler,
BlockPerCu>;
BlockPerCu,
AQRowMajor>;

template <typename QDataType,
typename OutDataType,
Expand All @@ -124,6 +128,9 @@ void TileGemmComputeImpl(ck_tile::QuantGemmHostArgs& args)
BQuantGroupSize::kN == 128 &&
(GemmConfig::M_Warp_v * GemmConfig::N_Warp_v * GemmConfig::K_Warp_v == 8) &&
GemmConfig::K_Warp_Tile_v == 128;
// When AQRowMajor is true for an 8-warp config, the kernel reads x_scale
// in row-major layout natively, avoiding the host-side transpose.
static constexpr bool aq_col_major = eight_waves && !GemmConfig::AQRowMajor_v;

using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile_v, GemmConfig::N_Tile_v, GemmConfig::K_Tile_v>,
Expand All @@ -145,7 +152,7 @@ void TileGemmComputeImpl(ck_tile::QuantGemmHostArgs& args)
BLayout,
CLayout,
QuantMode,
std::conditional_t<eight_waves, AQLayout_8Warps, AQLayout>,
std::conditional_t<aq_col_major, AQLayout_8Warps, AQLayout>,
BQLayout,
transpose_c,
UseDoubleSmemBuffer>;
Expand Down Expand Up @@ -308,7 +315,15 @@ __forceinline__ torch::Tensor gemm_a8w8_blockscale_cktile_impl(torch::Tensor& XQ
const int N = WQ.size(0);
const int K = XQ.size(1);

const bool eight_waves =
// Whether this kernel configuration uses column-major AQ layout,
// requiring a host-side transpose of x_scale.
constexpr bool aq_col_major =
BQuantGroupSize::kN == 128 &&
(GemmInstance::M_Warp_v * GemmInstance::N_Warp_v * GemmInstance::K_Warp_v == 8) &&
GemmInstance::K_Warp_Tile_v == 128 &&
!GemmInstance::AQRowMajor_v;

constexpr bool eight_waves =
BQuantGroupSize::kN == 128 &&
(GemmInstance::M_Warp_v * GemmInstance::N_Warp_v * GemmInstance::K_Warp_v == 8) &&
GemmInstance::K_Warp_Tile_v == 128;
Expand All @@ -321,18 +336,34 @@ __forceinline__ torch::Tensor gemm_a8w8_blockscale_cktile_impl(torch::Tensor& XQ
// through the async kernel launch.
torch::Tensor x_scale_t;

if(eight_waves && !PreshuffleB)
if constexpr(aq_col_major)
{
x_scale_t = x_scale.transpose(0, 1).contiguous().view(x_scale.sizes());
args.aq_ptr = x_scale_t.data_ptr();
// 8-warp ColumnMajor AQ: transpose x_scale to col-major
if(!PreshuffleB)
{
x_scale_t = x_scale.transpose(0, 1).contiguous().view(x_scale.sizes());
args.aq_ptr = x_scale_t.data_ptr();
}
else
{
args.aq_ptr = x_scale.data_ptr();
}
}
else if(!eight_waves && PreshuffleB)
else if constexpr(!eight_waves)
{
x_scale_t = x_scale.view({x_scale.size(1), x_scale.size(0)}).transpose(0, 1).contiguous();
args.aq_ptr = x_scale_t.data_ptr();
if(PreshuffleB)
{
x_scale_t = x_scale.view({x_scale.size(1), x_scale.size(0)}).transpose(0, 1).contiguous();
args.aq_ptr = x_scale_t.data_ptr();
}
else
{
args.aq_ptr = x_scale.data_ptr();
}
}
else
{
// 8-warp RowMajor AQ: use x_scale directly, no transpose needed
args.aq_ptr = x_scale.data_ptr();
Comment thread
samremes marked this conversation as resolved.
}

Expand All @@ -357,7 +388,7 @@ __forceinline__ torch::Tensor gemm_a8w8_blockscale_cktile_impl(torch::Tensor& XQ
const int stride_A = XQ.stride(0);
const int stride_B = WQ.stride(0);
const int stride_C = Y.stride(0);
const int stride_AQ = eight_waves ? M : static_cast<int>(x_scale.stride(0));
const int stride_AQ = aq_col_major ? M : static_cast<int>(x_scale.stride(0));
const int stride_BQ = w_scale.stride(0);

args.QK_A = AQK;
Expand Down
Loading
Loading