Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,033 changes: 551 additions & 482 deletions aiter/configs/a8w8_tuned_gemm.csv

Large diffs are not rendered by default.

94 changes: 67 additions & 27 deletions aiter/configs/a8w8_untuned_gemm.csv
Original file line number Diff line number Diff line change
@@ -1,27 +1,67 @@
M,N,K
1, 1280, 8192
32, 1280, 8192
64, 1280, 8192
128, 1280, 8192
192, 1280, 8192
256, 1280, 8192
320, 1280, 8192
512, 1280, 8192
1024, 1280, 8192
2048, 1280, 8192
4096, 1280, 8192
8192, 1280, 8192
16384, 1280, 8192
1, 8192, 1024
32, 8192, 1024
64, 8192, 1024
128, 8192, 1024
192, 8192, 1024
256, 8192, 1024
320, 8192, 1024
512, 8192, 1024
1024, 8192, 1024
2048, 8192, 1024
4096, 8192, 1024
8192, 8192, 1024
16384, 8192, 1024
M,N,K,q_dtype_w
1, 1280, 8192,torch.int8
32, 1280, 8192,torch.int8
64, 1280, 8192,torch.int8
128, 1280, 8192,torch.int8
192, 1280, 8192,torch.int8
256, 1280, 8192,torch.int8
320, 1280, 8192,torch.int8
512, 1280, 8192,torch.int8
1024, 1280, 8192,torch.int8
2048, 1280, 8192,torch.int8
4096, 1280, 8192,torch.int8
8192, 1280, 8192,torch.int8
16384, 1280, 8192,torch.int8
1, 8192, 1024,torch.int8
32, 8192, 1024,torch.int8
64, 8192, 1024,torch.int8
128, 8192, 1024,torch.int8
192, 8192, 1024,torch.int8
256, 8192, 1024,torch.int8
320, 8192, 1024,torch.int8
512, 8192, 1024,torch.int8
1024, 8192, 1024,torch.int8
2048, 8192, 1024,torch.int8
4096, 8192, 1024,torch.int8
8192, 8192, 1024,torch.int8
16384, 8192, 1024,torch.int8
16,1024,8192,torch.float8_e4m3fn
16,1280,8192,torch.float8_e4m3fn
16,3584,8192,torch.float8_e4m3fn
16,7168,8192,torch.float8_e4m3fn
16,8192,8192,torch.float8_e4m3fn
16,10240,8192,torch.float8_e4m3fn
16,28672,8192,torch.float8_e4m3fn
16,57344,8192,torch.float8_e4m3fn
32,1024,8192,torch.float8_e4m3fn
32,1280,8192,torch.float8_e4m3fn
32,3584,8192,torch.float8_e4m3fn
32,7168,8192,torch.float8_e4m3fn
32,8192,8192,torch.float8_e4m3fn
32,10240,8192,torch.float8_e4m3fn
32,28672,8192,torch.float8_e4m3fn
32,57344,8192,torch.float8_e4m3fn
64,1024,8192,torch.float8_e4m3fn
64,1280,8192,torch.float8_e4m3fn
64,3584,8192,torch.float8_e4m3fn
64,7168,8192,torch.float8_e4m3fn
64,8192,8192,torch.float8_e4m3fn
64,10240,8192,torch.float8_e4m3fn
64,28672,8192,torch.float8_e4m3fn
64,57344,8192,torch.float8_e4m3fn
128,1024,8192,torch.float8_e4m3fn
128,1280,8192,torch.float8_e4m3fn
128,3584,8192,torch.float8_e4m3fn
128,7168,8192,torch.float8_e4m3fn
128,8192,8192,torch.float8_e4m3fn
128,10240,8192,torch.float8_e4m3fn
128,28672,8192,torch.float8_e4m3fn
128,57344,8192,torch.float8_e4m3fn
256,1024,8192,torch.float8_e4m3fn
256,1280,8192,torch.float8_e4m3fn
256,3584,8192,torch.float8_e4m3fn
256,7168,8192,torch.float8_e4m3fn
256,8192,8192,torch.float8_e4m3fn
256,10240,8192,torch.float8_e4m3fn
256,28672,8192,torch.float8_e4m3fn
256,57344,8192,torch.float8_e4m3fn
37 changes: 22 additions & 15 deletions aiter/ops/gemm_op_a8w8.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

import torch
from torch import Tensor
Expand Down Expand Up @@ -295,37 +295,40 @@ def get_CKGEMM_config(M: int, N: int, K: int, tuned_file="a8w8_tuned_gemm.csv"):


@functools.lru_cache(maxsize=1024)
def get_bpreshuffle_GEMM_config(
def get_GEMM_config_with_quant_type(
M: int,
N: int,
K: int,
q_dtype_w: torch.dtype,
tuned_file=f"{AITER_ROOT_DIR}/aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv",
):
# Use dict to cache configs for different files
if not hasattr(get_bpreshuffle_GEMM_config, "file_cache"):
get_bpreshuffle_GEMM_config.file_cache = {}
if not hasattr(get_GEMM_config_with_quant_type, "file_cache"):
get_GEMM_config_with_quant_type.file_cache = {}

# Load file if not cached
if tuned_file not in get_bpreshuffle_GEMM_config.file_cache:
if tuned_file not in get_GEMM_config_with_quant_type.file_cache:
asmGemmDictDf = pd.read_csv(tuned_file).drop_duplicates()
get_bpreshuffle_GEMM_config.file_cache[tuned_file] = asmGemmDictDf.set_index(
["cu_num", "M", "N", "K", "q_dtype_w"]
).to_dict("index")
get_GEMM_config_with_quant_type.file_cache[tuned_file] = (
asmGemmDictDf.set_index(["cu_num", "M", "N", "K", "q_dtype_w"]).to_dict(
"index"
)
)

cu_num = get_cu_num()
padded_M = M
config = None
for gl in [None, 0, 1]:
padded_M = M if gl is None else get_padded_m(M, N, K, gl)
config = get_bpreshuffle_GEMM_config.file_cache[tuned_file].get(
config = get_GEMM_config_with_quant_type.file_cache[tuned_file].get(
(cu_num, padded_M, N, K, str(q_dtype_w)), None
)
if config is not None:
if AITER_LOG_TUNED_CONFIG:
logger.info(
f"shape M:{M}, N:{N}, K:{K} q_dtype_w:{q_dtype_w}, found padded_M: {padded_M}, N:{N}, K:{K} is tuned, in {tuned_file}, libtype is {config['libtype']}!"
)
msg = f"shape M:{M}, N:{N}, K:{K} q_dtype_w:{q_dtype_w}, found padded_M: {padded_M}, N:{N}, K:{K} is tuned, in {tuned_file}!"
if "libtype" in config:
msg += f" libtype is {config['libtype']}!"
logger.info(msg)
break
if config is None:
logger.info(
Expand Down Expand Up @@ -394,7 +397,7 @@ def gemm_a8w8_ASM(
x_scale.dtype == dtypes.fp32
and w_scale.dtype == dtypes.fp32
and (
asm_config := get_bpreshuffle_GEMM_config(
asm_config := get_GEMM_config_with_quant_type(
m,
n,
k,
Expand Down Expand Up @@ -434,7 +437,11 @@ def gemm_a8w8_CK(
m = XQ.shape[0]
n = WQ.shape[0]
k = XQ.shape[-1]
ck_config = get_CKGEMM_config(m, n, k, AITER_CONFIGS.AITER_CONFIG_GEMM_A8W8_FILE)

q_dtype_w = WQ.dtype if WQ.dtype in [dtypes.fp8, dtypes.i8] else dtypes.i8
ck_config = get_GEMM_config_with_quant_type(
m, n, k, q_dtype_w, AITER_CONFIGS.AITER_CONFIG_GEMM_A8W8_FILE
)
if splitK is None:
if ck_config is not None:
splitK = ck_config["splitK"]
Expand Down Expand Up @@ -488,7 +495,7 @@ def gemm_a8w8_bpreshuffle(
Y = torch.empty(m, n, dtype=dtype, device=XQ.device)

# CKTile only supports bf16 dtype
config = get_bpreshuffle_GEMM_config(
config = get_GEMM_config_with_quant_type(
m,
n,
k,
Expand Down
40 changes: 22 additions & 18 deletions csrc/ck_gemm_a8w8/gemm_a8w8_tune.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "gemm_a8w8_manifest.h"
#include "gemm_a8w8_lookup.h"
#include <string>
#include "py_itfs_common.h"

using RowwiseKernel = std::function<
torch::Tensor(torch::Tensor &, torch::Tensor &,
Expand Down Expand Up @@ -61,9 +62,9 @@ torch::Tensor gemm_a8w8_tune(
int kernelId,
int splitK)
{
TORCH_CHECK(XQ.dtype() == at::ScalarType::Char && XQ.dtype() == WQ.dtype(),
"Weights and activations should both be int8!");
TORCH_CHECK( x_scale.dtype() == w_scale.dtype(),
TORCH_CHECK(XQ.dtype() == WQ.dtype(),
"XQ and WQ should have the same dtype!");
TORCH_CHECK(x_scale.dtype() == w_scale.dtype(),
"Scales should have the same dtype!");
std::optional<torch::Tensor> bias = std::nullopt;

Expand All @@ -72,26 +73,29 @@ torch::Tensor gemm_a8w8_tune(
int K = XQ.size(1);
int KBatch = std::pow(2, splitK);

// if (x_scale.dtype() == at::ScalarType::Float && Y.dtype() == at::ScalarType::Half)
// {
// rowwise_dispatch<F32, F16>(kernelId)(XQ, WQ, x_scale, w_scale, Y, bias);
// }
// else if (x_scale.dtype() == at::ScalarType::Float && Y.dtype() == at::ScalarType::BFloat16)
// {
// rowwise_dispatch<F32, B16>(kernelId)(XQ, WQ, x_scale, w_scale, Y, bias);
// }
// else if (Y.dtype() == at::ScalarType::Half)
// {
// rowwise_dispatch<F16>(kernelId)(XQ, WQ, x_scale, w_scale, Y, bias);
// }
// else
// Check if input is INT8 or FP8
bool is_i8 = (XQ.dtype() == at::ScalarType::Char);
bool is_fp8 = (XQ.dtype() == torch_fp8);

TORCH_CHECK(is_i8 || is_fp8,
"XQ dtype must be int8 or fp8, got: " + std::string(c10::toString(XQ.dtype())));

if (Y.dtype() == at::ScalarType::BFloat16)
{
rowwise_dispatch<I8, B16>(kernelId)(XQ, WQ, x_scale, w_scale, Y, bias, KBatch);
if (is_i8)
{
// INT8 path
rowwise_dispatch<I8, B16, B16>(kernelId)(XQ, WQ, x_scale, w_scale, Y, bias, KBatch);
}
else
{
// FP8 path
rowwise_dispatch<F8, F32, B16>(kernelId)(XQ, WQ, x_scale, w_scale, Y, bias, KBatch);
}
}
else
{
TORCH_CHECK(false, "Unsupported scales/output dtype!");
TORCH_CHECK(false, "Unsupported output dtype: " + std::string(c10::toString(Y.dtype())));
}
return Y;
}
Loading