Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,028 changes: 514 additions & 514 deletions aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv

Large diffs are not rendered by default.

28 changes: 21 additions & 7 deletions aiter/jit/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1067,18 +1067,29 @@ def _ensure_loaded():
)
lib = ctypes.CDLL(so_path)
c_func = getattr(lib, fc_name)
c_func.restype = None

hints = typing.get_type_hints(func)

ret_hint = hints.get("return")
if ret_hint is int:
c_func.restype = ctypes.c_int
elif ret_hint is float:
c_func.restype = ctypes.c_float
else:
c_func.restype = None

argtypes = []
has_tensor = False
for pname in inspect.signature(func).parameters:
hint = hints.get(pname)
origin = typing.get_origin(hint)
type_args = typing.get_args(hint)
if hint is torch.Tensor:
argtypes.append(ctypes.POINTER(AiterTensor))
has_tensor = True
elif _is_union(origin) and torch.Tensor in type_args:
argtypes.append(ctypes.POINTER(AiterTensor))
has_tensor = True
elif _is_union(origin) and int in type_args:
argtypes.append(ctypes.c_int)
elif _is_union(origin) and str in type_args:
Expand All @@ -1091,11 +1102,13 @@ def _ensure_loaded():
argtypes.append(ctypes.c_float)
else:
argtypes.append(ctypes.c_void_p)
argtypes.append(ctypes.c_void_p) # hipStream_t
if has_tensor:
argtypes.append(ctypes.c_void_p) # hipStream_t
c_func.argtypes = argtypes

_cache["lib"] = lib
_cache["c_func"] = c_func
_cache["has_tensor"] = has_tensor

def _check_args_before_convert(bound_args, hints):
for pname, value in bound_args.items():
Expand Down Expand Up @@ -1197,10 +1210,11 @@ def caller(*args, **kwargs):
else:
c_args.append(value)

c_args.append(
ctypes.c_void_p(torch.cuda.current_stream(tensor_device).cuda_stream)
)
c_func(*c_args)
if _cache.get("has_tensor"):
c_args.append(
ctypes.c_void_p(torch.cuda.current_stream(tensor_device).cuda_stream)
)
return c_func(*c_args)

return caller

Expand All @@ -1220,7 +1234,7 @@ def decorator(func):

@functools.wraps(func)
def ctypes_wrapper(*args, **kwargs):
ctypes_caller(*args, **kwargs)
return ctypes_caller(*args, **kwargs)

@torch_compile_guard(device="cuda", calling_func_=func)
def ctypes_custom_wrapper(*args, **kwargs):
Expand Down
13 changes: 0 additions & 13 deletions aiter/jit/optCompilerConfig.json
100755 → 100644
Original file line number Diff line number Diff line change
@@ -1,16 +1,4 @@
{
"module_aiter_enum": {
"srcs": [
"f'{AITER_CSRC_DIR}/pybind/aiter_enum_pybind.cu'"
],
"flags_extra_cc": [],
"flags_extra_hip": [],
"extra_ldflags": "None",
"extra_include": [],
"verbose": "False",
"torch_exclude": "True",
"blob_gen_cmd": "''"
},
"module_activation": {
"srcs": [
"f'{AITER_CSRC_DIR}/pybind/activation_pybind.cu'",
Expand Down Expand Up @@ -198,7 +186,6 @@
},
"module_gemm_common": {
"srcs": [
"f'{AITER_CSRC_DIR}/pybind/gemm_common_pybind.cu'",
"f'{AITER_CSRC_DIR}/py_itfs_cu/gemm_common.cu'"
],
"flags_extra_cc": [],
Expand Down
29 changes: 19 additions & 10 deletions aiter/ops/enum.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,25 @@
from ..jit.core import compile_ops

# from enum import Enum as Enum
Enum = int
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

# Mirror of csrc/include/aiter_enum.h -- update both when changing enum values
from enum import IntEnum

@compile_ops("module_aiter_enum", "ActivationType")
def _ActivationType(dummy): ...
Enum = int


@compile_ops("module_aiter_enum", "QuantType")
def _QuantType(dummy): ...
class ActivationType(IntEnum):
No = -1
Silu = 0
Gelu = 1
Swiglu = 2


ActivationType = type(_ActivationType(0))
QuantType = type(_QuantType(0))
class QuantType(IntEnum):
No = 0
per_Tensor = 1
per_Token = 2
per_1x32 = 3
per_1x128 = 4
per_128x128 = 5
per_256x128 = 6
per_1024x128 = 7
8 changes: 3 additions & 5 deletions aiter/ops/gemm_op_common.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

from ..jit.core import (
compile_ops,
)
from ..jit.core import compile_ops


@compile_ops("module_gemm_common")
@compile_ops("module_gemm_common", fc_name="getPaddedM", ffi_type="ctypes")
def get_padded_m(M: int, N: int, K: int, gl: int) -> int: ...
1 change: 1 addition & 0 deletions csrc/include/aiter_enum.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once
// SPDX-License-Identifier: MIT
// Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
// Single source of truth: aiter/ops/enum.py parses enums from this file
Comment thread
yzhou103 marked this conversation as resolved.
#include <string>


Expand Down
11 changes: 9 additions & 2 deletions csrc/include/gemm_common.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
#pragma once
// SPDX-License-Identifier: MIT
// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <torch/extension.h>
// Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

#ifdef __cplusplus
extern "C" {
#endif

int getPaddedM(int M, int N, int K, int gl /*granularity level*/);

#ifdef __cplusplus
}
#endif
21 changes: 0 additions & 21 deletions csrc/include/rocm_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1673,27 +1673,6 @@ namespace py = pybind11;
m.def("rocb_mm", &RocSolIdxBlas, "mm"); \
m.def("rocb_findallsols", &RocFindAllSolIdxBlas, "rocblas_find_all_sols");

#define AITER_ENUM_PYBIND \
pybind11::enum_<QuantType>(m, "QuantType") \
.value("No", QuantType::No) \
.value("per_Tensor", QuantType::per_Tensor) \
.value("per_Token", QuantType::per_Token) \
.value("per_1x32", QuantType::per_1x32) \
.value("per_1x128", QuantType::per_1x128) \
.value("per_128x128", QuantType::per_128x128) \
.value("per_256x128", QuantType::per_256x128) \
.value("per_1024x128", QuantType::per_1024x128) \
.export_values(); \
pybind11::enum_<ActivationType>(m, "ActivationType") \
.value("No", ActivationType::No) \
.value("Silu", ActivationType::Silu) \
.value("Gelu", ActivationType::Gelu) \
.value("Swiglu", ActivationType::Swiglu) \
.export_values(); \
pybind11::implicitly_convertible<int, QuantType>(); \
pybind11::implicitly_convertible<int, ActivationType>();
#define GEMM_COMMON_PYBIND \
m.def("get_padded_m", &getPaddedM, py::arg("M"), py::arg("N"), py::arg("K"), py::arg("gl"));

#define TOP_K_PER_ROW_PYBIND \
m.def("top_k_per_row_prefill", \
Expand Down
20 changes: 10 additions & 10 deletions csrc/py_itfs_cu/gemm_common.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

#include <cmath>
#include "gemm_common.h"
// SPDX-License-Identifier: MIT
// Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
#include <climits>

static constexpr int nextPow2(unsigned int num)
{
Expand All @@ -9,33 +9,33 @@ static constexpr int nextPow2(unsigned int num)
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}

extern "C" __attribute__((visibility("default")))
int getPaddedM(int M, int N, int K, int gl) {
int padded_m = M;
// granularity level, gl = 0, Fine-grained search
if (gl == 0) {
if(M <= 256)
{
padded_m = (M + 15) / 16 * 16; // Round up to the next multiple of 16
padded_m = (M + 15) / 16 * 16;
}
else if(M <= 1024)
{
padded_m = (M + 31) / 32 * 32; // Round up to the next multiple of 32
padded_m = (M + 31) / 32 * 32;
}
else if(M <= 4096)
{
padded_m = (M + 63) / 64 * 64; // Round up to the next multiple of 64
padded_m = (M + 63) / 64 * 64;
}
else
{
padded_m = (M + 127) / 128 * 128; // Round up to the next multiple of 128
padded_m = (M + 127) / 128 * 128;
}
} else if (gl == 1) {
if (M > 8192 && N > 4096) {
padded_m = 8192;
} else {
padded_m = nextPow2(M);
}
}
}
return padded_m;

}
}
10 changes: 0 additions & 10 deletions csrc/pybind/aiter_enum_pybind.cu

This file was deleted.

9 changes: 0 additions & 9 deletions csrc/pybind/gemm_common_pybind.cu

This file was deleted.

3 changes: 1 addition & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,9 @@ def get_exclude_ops():
):
exclude_ops.append(module)
elif PREBUILD_KERNELS == 3:
# Keep only module_fmha_v3* and module_aiter_enum
# Keep only module_fmha_v3*
if not (
module.startswith("module_fmha_v3")
or module == "module_aiter_enum"
or module == "module_gemm_mi350_a8w8_blockscale_asm"
):
Comment thread
yzhou103 marked this conversation as resolved.
Outdated
exclude_ops.append(module)
Expand Down