Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
958 changes: 479 additions & 479 deletions aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv

Large diffs are not rendered by default.

28 changes: 21 additions & 7 deletions aiter/jit/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1092,18 +1092,29 @@ def _ensure_loaded():
)
lib = ctypes.CDLL(so_path)
c_func = getattr(lib, fc_name)
c_func.restype = None

hints = typing.get_type_hints(func)

ret_hint = hints.get("return")
if ret_hint is int:
c_func.restype = ctypes.c_int
elif ret_hint is float:
c_func.restype = ctypes.c_float
else:
c_func.restype = None

argtypes = []
has_tensor = False
for pname in inspect.signature(func).parameters:
hint = hints.get(pname)
origin = typing.get_origin(hint)
type_args = typing.get_args(hint)
if hint is torch.Tensor:
argtypes.append(ctypes.POINTER(aiter_tensor_t))
has_tensor = True
elif _is_union(origin) and torch.Tensor in type_args:
argtypes.append(ctypes.POINTER(aiter_tensor_t))
has_tensor = True
elif _is_union(origin) and int in type_args:
argtypes.append(ctypes.c_int)
elif _is_union(origin) and str in type_args:
Expand All @@ -1118,11 +1129,13 @@ def _ensure_loaded():
argtypes.append(ctypes.c_float)
else:
argtypes.append(ctypes.c_void_p)
argtypes.append(ctypes.c_void_p) # hipStream_t
if has_tensor:
argtypes.append(ctypes.c_void_p) # hipStream_t
c_func.argtypes = argtypes

_cache["lib"] = lib
_cache["c_func"] = c_func
_cache["has_tensor"] = has_tensor

def _check_args_before_convert(bound_args, hints):
for pname, value in bound_args.items():
Expand Down Expand Up @@ -1232,10 +1245,11 @@ def caller(*args, **kwargs):
else:
c_args.append(value)

c_args.append(
ctypes.c_void_p(torch.cuda.current_stream(tensor_device).cuda_stream)
)
c_func(*c_args)
if _cache.get("has_tensor"):
c_args.append(
ctypes.c_void_p(torch.cuda.current_stream(tensor_device).cuda_stream)
)
return c_func(*c_args)

return caller

Expand All @@ -1255,7 +1269,7 @@ def decorator(func):

@functools.wraps(func)
def ctypes_wrapper(*args, **kwargs):
ctypes_caller(*args, **kwargs)
return ctypes_caller(*args, **kwargs)

@torch_compile_guard(device="cuda", calling_func_=func)
def ctypes_custom_wrapper(*args, **kwargs):
Expand Down
1 change: 0 additions & 1 deletion aiter/jit/optCompilerConfig.json
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,6 @@
},
"module_gemm_common": {
"srcs": [
"f'{AITER_CSRC_DIR}/pybind/gemm_common_pybind.cu'",
"f'{AITER_CSRC_DIR}/py_itfs_cu/gemm_common.cu'"
],
"flags_extra_cc": [],
Expand Down
8 changes: 3 additions & 5 deletions aiter/ops/gemm_op_common.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

from ..jit.core import (
compile_ops,
)
from ..jit.core import compile_ops


@compile_ops("module_gemm_common")
@compile_ops("module_gemm_common", fc_name="getPaddedM", ffi_type="ctypes")
def get_padded_m(M: int, N: int, K: int, gl: int) -> int: ...
1 change: 1 addition & 0 deletions csrc/include/aiter_enum.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once
// SPDX-License-Identifier: MIT
// Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
// Single source of truth: aiter/ops/enum.py parses enums from this file
Comment thread
yzhou103 marked this conversation as resolved.
#include <string>


Expand Down
11 changes: 9 additions & 2 deletions csrc/include/gemm_common.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
#pragma once
// SPDX-License-Identifier: MIT
// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <torch/extension.h>
// Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

#ifdef __cplusplus
extern "C" {
#endif

int getPaddedM(int M, int N, int K, int gl /*granularity level*/);

#ifdef __cplusplus
}
#endif
21 changes: 0 additions & 21 deletions csrc/include/rocm_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1563,27 +1563,6 @@ namespace py = pybind11;
m.def("rocb_mm", &RocSolIdxBlas, "mm"); \
m.def("rocb_findallsols", &RocFindAllSolIdxBlas, "rocblas_find_all_sols");

#define AITER_ENUM_PYBIND \
pybind11::enum_<QuantType>(m, "QuantType") \
.value("No", QuantType::No) \
.value("per_Tensor", QuantType::per_Tensor) \
.value("per_Token", QuantType::per_Token) \
.value("per_1x32", QuantType::per_1x32) \
.value("per_1x128", QuantType::per_1x128) \
.value("per_128x128", QuantType::per_128x128) \
.value("per_256x128", QuantType::per_256x128) \
.value("per_1024x128", QuantType::per_1024x128) \
.export_values(); \
pybind11::enum_<ActivationType>(m, "ActivationType") \
.value("No", ActivationType::No) \
.value("Silu", ActivationType::Silu) \
.value("Gelu", ActivationType::Gelu) \
.value("Swiglu", ActivationType::Swiglu) \
.export_values(); \
pybind11::implicitly_convertible<int, QuantType>(); \
pybind11::implicitly_convertible<int, ActivationType>();
#define GEMM_COMMON_PYBIND \
m.def("get_padded_m", &getPaddedM, py::arg("M"), py::arg("N"), py::arg("K"), py::arg("gl"));

#define TOP_K_PER_ROW_PYBIND \
m.def("top_k_per_row_prefill", \
Expand Down
20 changes: 10 additions & 10 deletions csrc/py_itfs_cu/gemm_common.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

#include <cmath>
#include "gemm_common.h"
// SPDX-License-Identifier: MIT
// Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
#include <climits>

static constexpr int nextPow2(unsigned int num)
{
Expand All @@ -9,33 +9,33 @@ static constexpr int nextPow2(unsigned int num)
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}

extern "C" __attribute__((visibility("default")))
int getPaddedM(int M, int N, int K, int gl) {
int padded_m = M;
// granularity level, gl = 0, Fine-grained search
if (gl == 0) {
if(M <= 256)
{
padded_m = (M + 15) / 16 * 16; // Round up to the next multiple of 16
padded_m = (M + 15) / 16 * 16;
}
else if(M <= 1024)
{
padded_m = (M + 31) / 32 * 32; // Round up to the next multiple of 32
padded_m = (M + 31) / 32 * 32;
}
else if(M <= 4096)
{
padded_m = (M + 63) / 64 * 64; // Round up to the next multiple of 64
padded_m = (M + 63) / 64 * 64;
}
else
{
padded_m = (M + 127) / 128 * 128; // Round up to the next multiple of 128
padded_m = (M + 127) / 128 * 128;
}
} else if (gl == 1) {
if (M > 8192 && N > 4096) {
padded_m = 8192;
} else {
padded_m = nextPow2(M);
}
}
}
return padded_m;

}
}
22 changes: 19 additions & 3 deletions csrc/pybind/aiter_enum_pybind.cu
Original file line number Diff line number Diff line change
@@ -1,10 +1,26 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
#include <pybind11/pybind11.h>
#include "rocm_ops.hpp"
#include "aiter_enum.h"

PYBIND11_MODULE(module_aiter_enum, m)
{
AITER_ENUM_PYBIND;
pybind11::enum_<QuantType>(m, "QuantType")
.value("No", QuantType::No)
.value("per_Tensor", QuantType::per_Tensor)
.value("per_Token", QuantType::per_Token)
.value("per_1x32", QuantType::per_1x32)
.value("per_1x128", QuantType::per_1x128)
.value("per_128x128", QuantType::per_128x128)
.value("per_256x128", QuantType::per_256x128)
.value("per_1024x128", QuantType::per_1024x128)
.export_values();
pybind11::enum_<ActivationType>(m, "ActivationType")
.value("No", ActivationType::No)
.value("Silu", ActivationType::Silu)
.value("Gelu", ActivationType::Gelu)
.value("Swiglu", ActivationType::Swiglu)
.export_values();
pybind11::implicitly_convertible<int, QuantType>();
pybind11::implicitly_convertible<int, ActivationType>();
}
9 changes: 0 additions & 9 deletions csrc/pybind/gemm_common_pybind.cu

This file was deleted.

22 changes: 7 additions & 15 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,32 +149,24 @@ def get_exclude_ops():

for module in all_modules:
if PREBUILD_KERNELS == 1:
if "_tune" in module or module == "module_gemm_mi350_a8w8_blockscale_asm":
if "_tune" in module:
exclude_ops.append(module)
if "mha" in module and module not in [
"module_fmha_v3_fwd",
"module_fmha_v3_varlen_fwd",
]:
exclude_ops.append(module)
elif PREBUILD_KERNELS == 2:
# Exclude _bwd, _tune, and specific module
if (
"_bwd" in module
or "_tune" in module
or module == "module_gemm_mi350_a8w8_blockscale_asm"
):
# Exclude _bwd and _tune
if "_bwd" in module or "_tune" in module:
exclude_ops.append(module)
elif PREBUILD_KERNELS == 3:
# Keep only module_fmha_v3* and module_aiter_enum
if not (
module.startswith("module_fmha_v3")
or module == "module_aiter_enum"
or module == "module_gemm_mi350_a8w8_blockscale_asm"
):
# Keep only module_fmha_v3*
if not module.startswith("module_fmha_v3"):
exclude_ops.append(module)
else:
# Default behavior: exclude tunes and specific mi350 module
if "_tune" in module or module == "module_gemm_mi350_a8w8_blockscale_asm":
# Default behavior: exclude tunes
if "_tune" in module:
exclude_ops.append(module)

return exclude_ops
Expand Down
Loading