Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions aiter/jit/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

this_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, f"{this_dir}/utils/")
from chip_info import get_gfx
from chip_info import get_gfx, get_gfx_list
from cpp_extension import _jit_compile, get_hip_version
from file_baton import FileBaton
from torch_guard import torch_compile_guard # noqa: E402
Expand Down Expand Up @@ -264,9 +264,16 @@ def get_config_file(env_name, default_file, tuned_file_name):
sys.path.insert(0, AITER_META_DIR)
AITER_CSRC_DIR = f"{AITER_META_DIR}/csrc"
AITER_GRADLIB_DIR = f"{AITER_META_DIR}/gradlib"
gfx = get_gfx()
AITER_ASM_DIR = f"{AITER_META_DIR}/hsa/{gfx}/"
os.environ["AITER_ASM_DIR"] = AITER_ASM_DIR
gfx = get_gfx_list()
if len(gfx) == 1:
# single GPU arch
AITER_ASM_DIR = f"{AITER_META_DIR}/hsa/{gfx[0]}/"
os.environ["AITER_ASM_DIR"] = AITER_ASM_DIR
else:
# multiple GPU archs
AITER_ASM_DIR = [f"{AITER_META_DIR}/hsa/{g}/" for g in gfx]
os.environ["AITER_ASM_DIR"] = ":".join(AITER_ASM_DIR)

CK_3RDPARTY_DIR = os.environ.get(
"CK_DIR", f"{AITER_META_DIR}/3rdparty/composable_kernel"
)
Expand Down
12 changes: 6 additions & 6 deletions aiter/jit/optCompilerConfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@
"extra_ldflags": "None",
"extra_include": [],
"verbose": "False",
"blob_gen_cmd": "f'{get_asm_dir()}/pa/codegen.py --output_dir {{}}'"
"blob_gen_cmd": "f'{AITER_META_DIR}/hsa/codegen.py -m pa --output_dir {{}}'"
},
"module_pa": {
"srcs": [
Expand Down Expand Up @@ -317,7 +317,7 @@
"extra_ldflags": "None",
"extra_include": [],
"verbose": "False",
"blob_gen_cmd": "f'{get_asm_dir()}/i8gemm/codegen.py --output_dir {{}}'"
"blob_gen_cmd": "f'{AITER_META_DIR}/hsa/codegen.py -m i8gemm --output_dir {{}}'"
},
"module_gemm_a16w16_asm": {
"srcs": [
Expand All @@ -329,7 +329,7 @@
"extra_ldflags": "None",
"extra_include": [],
"verbose": "False",
"blob_gen_cmd": "f'{get_asm_dir()}/bf16gemm/codegen.py --output_dir {{}}'"
"blob_gen_cmd": "f'{AITER_META_DIR}/hsa/codegen.py -m bf16gemm --output_dir {{}}'"
},
"module_gemm_a4w4_asm": {
"srcs": [
Expand All @@ -341,7 +341,7 @@
"extra_ldflags": "None",
"extra_include": [],
"verbose": "False",
"blob_gen_cmd": "f'{get_asm_dir()}/f4gemm/codegen.py --output_dir {{}}'"
"blob_gen_cmd": "f'{AITER_META_DIR}/hsa/codegen.py -m f4gemm --output_dir {{}}'"
},
"module_gemm_a8w8_blockscale_asm": {
"srcs": [
Expand Down Expand Up @@ -386,8 +386,8 @@
],
"verbose": "False",
"blob_gen_cmd": [
"f'{get_asm_dir()}/fmoe_2stages/codegen.py --output_dir {{}}'",
"f'{get_asm_dir()}/fmoe/codegen.py --output_dir {{}}'"
"f'{AITER_META_DIR}/hsa/codegen.py -m fmoe_2stages --output_dir {{}}'",
"f'{AITER_META_DIR}/hsa/codegen.py -m fmoe --output_dir {{}}'"
]
},
"module_moe_ck2stages": {
Expand Down
2 changes: 2 additions & 0 deletions aiter/ops/gemm_op_a16w16.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def gen_gemm_a16w16_asm_fake_tensors(
bias: Optional[Tensor] = None,
splitK: Optional[int] = None,
kernelName: Optional[str] = None,
bpreshuffle: bool = False,
) -> Tensor:
return out

Expand All @@ -39,6 +40,7 @@ def gemm_a16w16_asm(
bias: Optional[Tensor] = None,
splitK: Optional[int] = None,
kernelName: Optional[str] = None,
bpreshuffle: bool = False,
) -> Tensor: ...


Expand Down
3 changes: 2 additions & 1 deletion csrc/include/asm_gemm_a16w16.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ torch::Tensor gemm_a16w16_asm(torch::Tensor& A, // A:[M, K] bf16
torch::Tensor& out, // Out:[M, N] f32
std::optional<torch::Tensor> bias,
std::optional<int> splitK,
std::optional<std::string> kernelName);
std::optional<std::string> kernelName,
bool bpreshuffle = false);
3 changes: 2 additions & 1 deletion csrc/include/rocm_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,8 @@ namespace py = pybind11;
py::arg("out"), \
py::arg("bias") = std::nullopt, \
py::arg("splitK") = std::nullopt, \
py::arg("kernelName") = std::nullopt);
py::arg("kernelName") = std::nullopt, \
py::arg("bpreshuffle") = false);

#define GEMM_A4W4_ASM_PYBIND \
m.def("gemm_a4w4_asm", \
Expand Down
7 changes: 5 additions & 2 deletions csrc/py_itfs_cu/asm_fmoe.cu
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,8 @@ FMoeKernel* get_heuristic_kernel(
uint32_t tg_num = 0;
uint32_t num_persistent_tgs = 0;
uint32_t round = 0xffffffff;
std::string selectedKl = kernel_name;
std::string arch_id = get_gpu_arch();
std::string selectedKl = kernel_name.empty() ? "" : arch_id + kernel_name;
int vskip = 1;
static std::unordered_map<std::string, std::unique_ptr<FMoeKernel>> impl_ptr_map;

Expand All @@ -271,6 +272,8 @@ FMoeKernel* get_heuristic_kernel(
{
for(const auto& el : *cfgs)
{
if (el.first.find(arch_id) != 0)
continue;
const auto& cfg = el.second;
if(cfg.vskip == vskip && cfg.smf == smf)
{
Expand Down Expand Up @@ -312,7 +315,7 @@ FMoeKernel* get_heuristic_kernel(
if(it != cfgs->end())
{
const auto& cfg = it->second;
const char* name = cfg.name.c_str();
const char* name = cfg.knl_name.c_str();
const char* co_name = cfg.co_name.c_str();
auto result = impl_ptr_map.emplace(name, nullptr);
if(cfg.ps == 1)
Expand Down
41 changes: 29 additions & 12 deletions csrc/py_itfs_cu/asm_gemm_a16w16.cu
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,18 @@ struct __attribute__((packed)) KernelArgs
unsigned int K;
p3 _p16;
unsigned int splitk;
p2 _p17;
p3 _p17;
unsigned int is_out_b16;
p3 _p18;
};

std::tuple<std::string, int>
get_heuristic_kernel(int M,
int N,
int K,
CFG* cfgs,
std::string arch_id,
bool bpreshuffle,
std::optional<int> splitk = std::nullopt,
std::optional<std::string> kernelName = std::nullopt)
{
Expand All @@ -75,10 +79,12 @@ get_heuristic_kernel(int M,

for(const auto& el : *cfgs)
{
if (el.first.find(arch_id) != 0)
continue;
const auto& cfg = el.second;
Comment on lines +82 to 84
Copy link

Copilot AI Nov 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

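The architecture filtering logic assumes that el.first (the config-map key, which is the architecture ID followed by the kernel name) starts with the architecture ID: `el.first.find(arch_id) != 0` skips every entry whose key does not begin with arch_id. This implicit dependency on the key format is fragile and not documented, so an entry whose key does not follow the convention is silently ignored. Consider adding a comment explaining this assumption, or better yet, match on an explicit architecture field of the configuration struct (for example cfg.arch, if the generated struct provides one) instead of the key prefix.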

Suggested change
if (el.first.find(arch_id) != 0)
continue;
const auto& cfg = el.second;
const auto& cfg = el.second;
// Explicitly match architecture using cfg.arch instead of relying on kernel name format.
if (cfg.arch != arch_id)
continue;

Copilot uses AI. Check for mistakes.
if(kernelName.has_value() && kernelName.value() != el.first)
if(kernelName.has_value() && el.first != (arch_id + kernelName.value()))
continue;
if(N % cfg.tileN == 0)
if(N % cfg.tileN == 0 && cfg.bPreshuffle == (bpreshuffle ? 1 : 0))
{
// 1. select splitK
int split_K = 1;
Expand Down Expand Up @@ -125,6 +131,7 @@ get_heuristic_kernel(int M,
compute2mem_effi = local_compute2mem_effi;
oob = (M % cfg.tileM == 0) ? 0 : cfg.tileM - (M % cfg.tileM);
selectedKernelName = el.first;
// printf("Selected Kernel: %s\n", selectedKernelName.c_str());
selectedsplitK = split_K;
}
}
Expand All @@ -139,11 +146,13 @@ torch::Tensor gemm_a16w16_asm(torch::Tensor& A, // A:[M, K] bf16
torch::Tensor& out, // Out:[M, N] f32
std::optional<torch::Tensor> bias,
std::optional<int> splitK,
std::optional<std::string> kernelName)
std::optional<std::string> kernelName,
bool bpreshuffle = false)
{
TORCH_CHECK(out.dtype() == torch::ScalarType::Float,
"GEMM A16W16 asm only support Float32 output now!");

TORCH_CHECK(out.dtype() == torch::ScalarType::Float || out.dtype() == torch::ScalarType::BFloat16,
"GEMM A16W16 asm only support Float32 or Bf16 output now!");

std::string arch_id = get_gpu_arch();
// 1. prepare args
int Mdim = A.size(0);
int Ndim = B.size(0);
Expand All @@ -167,10 +176,14 @@ torch::Tensor gemm_a16w16_asm(torch::Tensor& A, // A:[M, K] bf16
int strideA1 = 0;
int strideB0 = 0;
int strideB1 = 0;
int is_out_b16 = 0;
// A row major, B col major, C row major
strideA0 = strideA1 = Kdim * 2; // in bytes
strideB0 = strideB1 = Kdim * 2;
strideC0 = strideC1 = strideD0 = strideD1 = Ndim * 4; // inbytes
const auto elem_bytes = out.element_size();
strideC0 = strideC1 = strideD0 = strideD1 = Ndim * elem_bytes; // inbytes
if (out.dtype() == torch::ScalarType::BFloat16)
is_out_b16 = 1;

szA += sz_A_pad;
szB += sz_B_pad;
Expand All @@ -191,6 +204,7 @@ torch::Tensor gemm_a16w16_asm(torch::Tensor& A, // A:[M, K] bf16
args.M = Mdim;
args.N = Ndim;
args.K = Kdim;
args.is_out_b16 = is_out_b16;

// args.stride_D0 = 25;
// args.stride_D1 = 80;
Expand All @@ -200,18 +214,20 @@ torch::Tensor gemm_a16w16_asm(torch::Tensor& A, // A:[M, K] bf16
// 2. select kl
static std::unordered_map<std::string, std::unique_ptr<AiterAsmKernel>> impl_ptr_map;
AiterAsmKernel* impl_ptr = nullptr;
CFG* config_map = &cfg_bf16gemm_outf32;
CFG* config_map = &cfg_bf16gemm_fp32bf16;

// 2.1 static dict
std::string selectedKernelName = kernelName.value_or("");
std::string selectedKernelName = kernelName.has_value() ? arch_id + kernelName.value() : "";
int selectedksplit = splitK.value_or(0) ?: 1;
if(!kernelName.has_value() || kernelName == "")
if(!kernelName.has_value() || kernelName == "" || !splitK.has_value())
{

auto it_sel = get_heuristic_kernel(Mdim,
Ndim,
Kdim,
config_map,
arch_id,
bpreshuffle,
splitK.has_value() ? splitK : std::nullopt,
kernelName.has_value() ? kernelName : std::nullopt);
selectedKernelName = std::get<0>(it_sel);
Expand All @@ -237,13 +253,14 @@ torch::Tensor gemm_a16w16_asm(torch::Tensor& A, // A:[M, K] bf16
// printf("N: %u\n", args.N);
// printf("K: %u\n", args.K);
// printf("splitk: %u\n", args.splitk);
// printf("is_out_b16: %u\n", args.is_out_b16);
// printf("=======================================\n");

auto it_kl = config_map->find(selectedKernelName);
if(it_kl != config_map->end())
{
const auto& cfg = it_kl->second;
const char* name = cfg.name.c_str();
const char* name = cfg.knl_name.c_str();
const char* co_name = cfg.co_name.c_str();
SUBM = cfg.tileM;
SUBN = cfg.tileN;
Expand Down
21 changes: 14 additions & 7 deletions csrc/py_itfs_cu/asm_gemm_a4w4.cu
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,7 @@ static CFG* get_cfg(torch::Tensor& inp, torch::Tensor& out)
{

#if defined(__Float4_e2m1fn_x2)
if(inp.dtype() == torch_fp4x2 &&
out.scalar_type() == at::ScalarType::BFloat16)
if(inp.dtype() == torch_fp4x2 && out.scalar_type() == at::ScalarType::BFloat16)
#else
if((inp.dtype() == torch::kUInt8) && out.scalar_type() == at::ScalarType::BFloat16)
#endif
Expand All @@ -87,6 +86,7 @@ static CFG* get_cfg(torch::Tensor& inp, torch::Tensor& out)
std::tuple<std::string, int> get_heuristic_kernel(int M,
int N,
int K,
std::string arch_id,
std::optional<int> log2_k_split,
std::optional<bool> bpreshuffle,
CFG* cfgs)
Expand All @@ -107,6 +107,8 @@ std::tuple<std::string, int> get_heuristic_kernel(int M,

for(const auto& el : *cfgs)
{
if(el.first.find(arch_id) != 0)
continue;
const auto& cfg = el.second;
if(cfg.bpreshuffle == bpreshuffle_en &&
((cfg.splitK == log2_k_split_en) || !log2_k_split.has_value()))
Expand Down Expand Up @@ -197,8 +199,8 @@ torch::Tensor gemm_a4w4_asm(torch::Tensor& A, // A:[M, K/2] f4x2

const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(A));
const hipStream_t stream = at::hip::getCurrentHIPStream();
CFG* config_map = get_cfg(A, out);
using DictKey = std::tuple<int, int, int, std::optional<int>, std::optional<bool>>;
CFG* config_map = get_cfg(A, out);
using DictKey = std::tuple<int, int, int, std::optional<int>, std::optional<bool>>;
struct SimpleHash
{
size_t operator()(const DictKey& key) const
Expand All @@ -220,6 +222,9 @@ torch::Tensor gemm_a4w4_asm(torch::Tensor& A, // A:[M, K/2] f4x2

static std::unordered_map<std::string, std::unique_ptr<AiterAsmKernel>> impl_ptr_map;

std::string arch_id = get_gpu_arch();
kernelName = kernelName.empty() ? "" : arch_id + kernelName;

int selectedksplit = log2_k_split.has_value() ? log2_k_split.value() : 0;
if(kernelName.empty())
{
Expand All @@ -232,7 +237,8 @@ torch::Tensor gemm_a4w4_asm(torch::Tensor& A, // A:[M, K/2] f4x2
}
else
{
auto it = get_heuristic_kernel(Mdim, Ndim, Kdim, log2_k_split, bpreshuffle, config_map);
auto it = get_heuristic_kernel(
Mdim, Ndim, Kdim, arch_id, log2_k_split, bpreshuffle, config_map);

kernelName = std::get<0>(it);
selectedksplit = std::get<1>(it);
Expand All @@ -250,7 +256,7 @@ torch::Tensor gemm_a4w4_asm(torch::Tensor& A, // A:[M, K/2] f4x2
if(it != config_map->end())
{
const auto& cfg = it->second;
const char* name = cfg.name.c_str();
const char* name = cfg.knl_name.c_str();
const char* co_name = cfg.co_name.c_str();
SUBM = cfg.tile_M;
SUBN = cfg.tile_N;
Expand All @@ -260,7 +266,8 @@ torch::Tensor gemm_a4w4_asm(torch::Tensor& A, // A:[M, K/2] f4x2
args.log2_k_split = selectedksplit;
int k_num = 1 << args.log2_k_split;
TORCH_CHECK(Kdim % k_num == 0, __func__, " Kdim % (1 << args.log2_k_split) != 0 !");
if(k_num>1)out.zero_();
if(k_num > 1)
out.zero_();
int k_per_tg = Kdim / k_num;
k_per_tg = ((k_per_tg + 256 - 1) / 256) * 256;
gdz = (Kdim + k_per_tg - 1) / k_per_tg;
Expand Down
Loading
Loading