Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions aiter/ops/gemm_op_a8w8.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ def gemm_a8w8_blockscale_ck(
x_scale: torch.Tensor,
w_scale: torch.Tensor,
Out: torch.Tensor,
splitK: int = 0,
) -> torch.Tensor: ...


Expand All @@ -231,6 +232,7 @@ def gemm_a8w8_blockscale_cktile(
w_scale: torch.Tensor,
Out: torch.Tensor,
isBpreshuffled: bool = False,
splitK: int = 0,
) -> torch.Tensor: ...


Expand Down Expand Up @@ -685,10 +687,15 @@ def gemm_a8w8_blockscale(
)
if config is not None:
libtype = config["libtype"]
splitK = int(config.get("splitK", 0))
if libtype == "ck":
return gemm_a8w8_blockscale_ck(XQ, WQ, x_scale, w_scale, Y)
return gemm_a8w8_blockscale_ck(
XQ, WQ, x_scale, w_scale, Y, splitK=splitK
)
elif libtype == "cktile":
return gemm_a8w8_blockscale_cktile(XQ, WQ, x_scale, w_scale, Y)
return gemm_a8w8_blockscale_cktile(
XQ, WQ, x_scale, w_scale, Y, splitK=splitK
)
Comment thread
samremes marked this conversation as resolved.
else:
assert 0, f"Unsupported libtype {libtype} for gemm_a8w8_blockscale"
try:
Expand Down
21 changes: 13 additions & 8 deletions csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale.cu
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// SPDX-License-Identifier: MIT
// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#include <cmath>
#include <functional>
#include <unordered_map>

Expand All @@ -15,7 +14,7 @@
#include "gemm_a8w8_blockscale_manifest.h"

using BlockwiseKernel = std::function<torch::Tensor(
torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&)>;
torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&, int)>;

using BlockwiseKernelMap = GemmDispatchMap<BlockwiseKernel>;

Expand Down Expand Up @@ -83,22 +82,28 @@ torch::Tensor gemm_a8w8_blockscale(torch::Tensor& XQ,
torch::Tensor& WQ,
torch::Tensor& x_scale,
torch::Tensor& w_scale,
torch::Tensor& Y)
torch::Tensor& Y,
int splitK)
{
TORCH_CHECK(XQ.dtype() == WQ.dtype(), "Weights and activations should have the same dtype!");
TORCH_CHECK(x_scale.dtype() == w_scale.dtype(), "Scales should have the same dtype!");

int M = XQ.size(0);
int N = WQ.size(0);
int K = XQ.size(1);
TORCH_CHECK(splitK >= 0 && splitK <= 30,
"splitK must be in the range [0, 30], got ",
splitK);

int M = XQ.size(0);
int N = WQ.size(0);
int K = XQ.size(1);
int KBatch = 1 << splitK;

if(x_scale.dtype() == at::ScalarType::Float && Y.dtype() == at::ScalarType::Half)
{
blockscale_dispatch<FP32, FP16>(M, N, K)(XQ, WQ, x_scale, w_scale, Y);
blockscale_dispatch<FP32, FP16>(M, N, K)(XQ, WQ, x_scale, w_scale, Y, KBatch);
}
else if(x_scale.dtype() == at::ScalarType::Float && Y.dtype() == at::ScalarType::BFloat16)
{
blockscale_dispatch<FP32, BF16>(M, N, K)(XQ, WQ, x_scale, w_scale, Y);
blockscale_dispatch<FP32, BF16>(M, N, K)(XQ, WQ, x_scale, w_scale, Y, KBatch);
}
else
{
Expand Down
23 changes: 14 additions & 9 deletions csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_cktile.cu
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// SPDX-License-Identifier: MIT
// Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

#include <cmath>
#include <functional>
#include <unordered_map>

Expand All @@ -15,7 +14,7 @@
#include "gemm_a8w8_blockscale_cktile_manifest.h"

using BlockwiseKernel = std::function<torch::Tensor(
torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&, bool)>;
torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&, bool, int)>;

using BlockwiseKernelMap = GemmDispatchMap<BlockwiseKernel>;

Expand Down Expand Up @@ -83,24 +82,30 @@ torch::Tensor gemm_a8w8_blockscale_cktile(torch::Tensor& XQ,
torch::Tensor& x_scale,
torch::Tensor& w_scale,
torch::Tensor& Y,
bool preshuffleB)
bool preshuffleB,
int splitK)
{
TORCH_CHECK(XQ.dtype() == WQ.dtype(), "Weights and activations should have the same dtype!");
TORCH_CHECK(x_scale.dtype() == w_scale.dtype(), "Scales should have the same dtype!");

int M = XQ.size(0);
int N = WQ.size(0);
int K = XQ.size(1);
TORCH_CHECK(splitK >= 0 && splitK <= 30,
"splitK must be in the range [0, 30], got ",
splitK);

int M = XQ.size(0);
int N = WQ.size(0);
int K = XQ.size(1);
int KBatch = 1 << splitK;

if(x_scale.dtype() == at::ScalarType::Float && Y.dtype() == at::ScalarType::Half)
{
blockscale_dispatch<TILE_FP32, TILE_FP16>(M, N, K)(
XQ, WQ, x_scale, w_scale, Y, preshuffleB);
XQ, WQ, x_scale, w_scale, Y, preshuffleB, KBatch);
}
else if(x_scale.dtype() == at::ScalarType::Float && Y.dtype() == at::ScalarType::BFloat16)
{
blockscale_dispatch<TILE_FP32, TILE_BF16>(M, N, K)(
XQ, WQ, x_scale, w_scale, Y, preshuffleB);
XQ, WQ, x_scale, w_scale, Y, preshuffleB, KBatch);
}
else
{
Expand All @@ -116,5 +121,5 @@ torch::Tensor gemm_a8w8_blockscale_bpreshuffle_cktile(torch::Tensor& XQ,
torch::Tensor& Y,
bool preshuffleB)
{
return gemm_a8w8_blockscale_cktile(XQ, WQ, x_scale, w_scale, Y, preshuffleB);
return gemm_a8w8_blockscale_cktile(XQ, WQ, x_scale, w_scale, Y, preshuffleB, 0);
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
#include "gemm_a8w8_blockscale_cktile_manifest.h"

using BlockwiseKernel = std::function<torch::Tensor(
torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&, bool)>;
torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&, bool, int)>;

// For certain high priority shapes, we directly use the best kernel rather
// than use heuristics.
Expand Down Expand Up @@ -73,12 +73,12 @@ torch::Tensor gemm_a8w8_blockscale_cktile_tune(torch::Tensor& XQ,
if(Y.dtype() == at::ScalarType::BFloat16)
{
blockwise_dispatch_cktile<TILE_FP32, TILE_BF16>(kernelId)(
XQ, WQ, x_scale, w_scale, Y, preshuffleB);
XQ, WQ, x_scale, w_scale, Y, preshuffleB, KBatch);
}
else if(Y.dtype() == at::ScalarType::Half)
{
blockwise_dispatch_cktile<TILE_FP32, TILE_FP16>(kernelId)(
XQ, WQ, x_scale, w_scale, Y, preshuffleB);
XQ, WQ, x_scale, w_scale, Y, preshuffleB, KBatch);
}
else
{
Expand Down
6 changes: 3 additions & 3 deletions csrc/ck_gemm_a8w8_blockscale/gemm_a8w8_blockscale_tune.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
#include "gemm_a8w8_blockscale_manifest.h"

using BlockwiseKernel = std::function<torch::Tensor(
torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&)>;
torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&, torch::Tensor&, int)>;

// For certain high priority shapes, we directly use the best kernel rather
// than use heuristics.
Expand Down Expand Up @@ -71,11 +71,11 @@ torch::Tensor gemm_a8w8_blockscale_tune(torch::Tensor& XQ,

if(Y.dtype() == at::ScalarType::BFloat16)
{
blockwise_dispatch<FP32, BF16>(kernelId)(XQ, WQ, x_scale, w_scale, Y);
blockwise_dispatch<FP32, BF16>(kernelId)(XQ, WQ, x_scale, w_scale, Y, KBatch);
}
else if(Y.dtype() == at::ScalarType::Half)
{
blockwise_dispatch<FP32, FP16>(kernelId)(XQ, WQ, x_scale, w_scale, Y);
blockwise_dispatch<FP32, FP16>(kernelId)(XQ, WQ, x_scale, w_scale, Y, KBatch);
}
else
{
Expand Down
11 changes: 7 additions & 4 deletions csrc/ck_gemm_a8w8_blockscale/gen_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ def gen_ck_instance(self, k: KernelInstance):
torch::Tensor &WQ,
torch::Tensor &x_scale,
torch::Tensor &w_scale,
torch::Tensor &Y
torch::Tensor &Y,
int KBatch
)
{{
// Get M, N, K from input tensors.
Expand Down Expand Up @@ -186,7 +187,7 @@ def gen_ck_instance(self, k: KernelInstance):
ck::tensor_operation::device::GemmSpecialization::{{GemmSpec}}>;

// Run kernel instance.
return gemm_a8w8_blockscale_impl<DDataType, EDataType, LegacyGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
return gemm_a8w8_blockscale_impl<DDataType, EDataType, LegacyGemmInstance>(XQ, WQ, x_scale, w_scale, Y, KBatch);
"""
INSTANCE_IMPL_str = (
LEGACY_INSTANCE_IMPL.replace(
Expand Down Expand Up @@ -238,7 +239,8 @@ def gen_ck_instance(self, k: KernelInstance):
torch::Tensor &WQ,
torch::Tensor &x_scale,
torch::Tensor &w_scale,
torch::Tensor &Y
torch::Tensor &Y,
int KBatch
);

"""
Expand Down Expand Up @@ -312,7 +314,8 @@ def gen_manifest_head(self, kernels_dict):
torch::Tensor &WQ,
torch::Tensor &x_scale,
torch::Tensor &w_scale,
torch::Tensor &Y);
torch::Tensor &Y,
int KBatch);
"""
MAINFEST_end = """

Expand Down
11 changes: 7 additions & 4 deletions csrc/ck_gemm_a8w8_blockscale/gen_instances_cktile.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ def gen_cktile_instance(self, k: TileKernelInstance):
torch::Tensor &x_scale,
torch::Tensor &w_scale,
torch::Tensor &Y,
bool preshuffleB
bool preshuffleB,
int k_batch
)
{{
// Get M, N, K from input tensors.
Expand All @@ -100,7 +101,7 @@ def gen_cktile_instance(self, k: TileKernelInstance):
{str(k.AQRowMajor).lower()}>;

// Run kernel instance.
return gemm_a8w8_blockscale_cktile_impl<DDataType, EDataType, TileGemmInstance>(XQ, WQ, x_scale, w_scale, Y, preshuffleB);
return gemm_a8w8_blockscale_cktile_impl<DDataType, EDataType, TileGemmInstance>(XQ, WQ, x_scale, w_scale, Y, preshuffleB, k_batch);
"""

TILE_INSTANCE_IMPL_str = TILE_INSTANCE_IMPL.replace(
Expand All @@ -123,7 +124,8 @@ def gen_cktile_instance(self, k: TileKernelInstance):
torch::Tensor &x_scale,
torch::Tensor &w_scale,
torch::Tensor &Y,
bool preshuffleB
bool preshuffleB,
int k_batch
);

"""
Expand Down Expand Up @@ -198,7 +200,8 @@ def gen_manifest_head(self, kernels_dict):
torch::Tensor &x_scale,
torch::Tensor &w_scale,
torch::Tensor &Y,
bool preshuffleB);
bool preshuffleB,
int k_batch);
"""
MAINFEST_end = """

Expand Down
3 changes: 2 additions & 1 deletion csrc/ck_gemm_a8w8_blockscale/include/gemm_a8w8_blockscale.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ torch::Tensor gemm_a8w8_blockscale(torch::Tensor& XQ,
torch::Tensor& WQ,
torch::Tensor& x_scale,
torch::Tensor& w_scale,
torch::Tensor& Y);
torch::Tensor& Y,
int splitK = 0);

torch::Tensor gemm_a8w8_blockscale_tune(torch::Tensor& XQ,
torch::Tensor& WQ,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ torch::Tensor gemm_a8w8_blockscale_cktile(torch::Tensor& XQ,
torch::Tensor& x_scale,
torch::Tensor& w_scale,
torch::Tensor& Y,
bool preshuffleB);
bool preshuffleB,
int splitK = 0);

torch::Tensor gemm_a8w8_blockscale_cktile_tune(torch::Tensor& XQ,
torch::Tensor& WQ,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,11 +234,6 @@ void TileGemmComputeImpl(ck_tile::QuantGemmHostArgs& args)
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();

if(args.k_batch != 1)
{
throw std::runtime_error("split-k is not supported yet!");
}

if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
Expand Down Expand Up @@ -283,7 +278,8 @@ __forceinline__ torch::Tensor gemm_a8w8_blockscale_cktile_impl(torch::Tensor& XQ
torch::Tensor& x_scale,
torch::Tensor& w_scale,
torch::Tensor& Y,
bool PreshuffleB)
bool PreshuffleB,
int k_batch = 1)
{
// check
TORCH_CHECK(XQ.dtype() == WQ.dtype(), "Weights and activations should have the same dtype!");
Expand Down Expand Up @@ -372,8 +368,7 @@ __forceinline__ torch::Tensor gemm_a8w8_blockscale_cktile_impl(torch::Tensor& XQ
args.bq_ptr = w_scale.data_ptr();
args.c_ptr = Y.data_ptr();

// split-k is not supported yet for tile quant gemm, set k_batch to 1
args.k_batch = 1;
args.k_batch = k_batch;
args.M = M;
args.N = N;
args.K = K;
Expand All @@ -400,6 +395,14 @@ __forceinline__ torch::Tensor gemm_a8w8_blockscale_cktile_impl(torch::Tensor& XQ
args.stride_AQ = stride_AQ;
args.stride_BQ = stride_BQ;

// Split-K uses atomic_add into C; zero the output buffer first.
// Use zero_() so all rows are cleared regardless of the leading-dimension
// stride (e.g. padded tensors produced by vLLM's _maybe_pad_fp8_weight).
if(k_batch > 1)
{
Y.zero_();
}

// do tile GEMM
if(PreshuffleB)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ __forceinline__ torch::Tensor gemm_a8w8_blockscale_impl(torch::Tensor& XQ,
torch::Tensor& WQ,
torch::Tensor& x_scale,
torch::Tensor& w_scale,
torch::Tensor& Y)
torch::Tensor& Y,
int KBatch = 1)
{
int M = XQ.size(0);
int N = WQ.size(0);
Expand Down Expand Up @@ -160,6 +161,13 @@ __forceinline__ torch::Tensor gemm_a8w8_blockscale_impl(torch::Tensor& XQ,
b_element_op,
cde_element_op);

TORCH_CHECK(KBatch >= 1, "KBatch must be >= 1, got ", KBatch);

if(KBatch > 1)
{
device_gemm.SetKBatch(&argument, KBatch);
}

TORCH_CHECK(device_gemm.IsSupportedArgument(argument), "This GEMM is not supported!");

invoker.Run(argument, StreamConfig{at::hip::getCurrentHIPStream()});
Expand Down
6 changes: 4 additions & 2 deletions csrc/include/rocm_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,8 @@ namespace py = pybind11;
py::arg("WQ"), \
py::arg("x_scale"), \
py::arg("w_scale"), \
py::arg("Out"));
py::arg("Out"), \
py::arg("splitK") = 0);

#define GEMM_A8W8_BLOCKSCALE_TUNE_PYBIND \
m.def("gemm_a8w8_blockscale_tune", \
Expand All @@ -520,7 +521,8 @@ namespace py = pybind11;
py::arg("x_scale"), \
py::arg("w_scale"), \
py::arg("Out"), \
py::arg("preshuffleB") = false);
py::arg("preshuffleB") = false, \
py::arg("splitK") = 0);

#define GEMM_A8W8_BLOCKSCALE_CKTILE_TUNE_PYBIND \
m.def("gemm_a8w8_blockscale_cktile_tune", \
Expand Down
Loading
Loading