Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion aiter/ops/triton/_triton_kernels/gemm_a16w16.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,27 @@
from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd
from ..utils._triton import arch_info
from ..utils.core import AITER_TRITON_CONFIGS_PATH
from ..utils._triton.kernel_repr import make_kernel_repr


_gemm_a16w16_repr = make_kernel_repr(
"_gemm_a16_w16_kernel",
[
"BLOCK_SIZE_M",
"BLOCK_SIZE_N",
"BLOCK_SIZE_K",
"GROUP_SIZE_M",
"NUM_KSPLIT",
"SPLITK_BLOCK_SIZE",
"EVEN_K",
"GRID_MN",
"cache_modifier",
"activation",
"use_activation",
"ADD_BIAS",
"SKIP_REDUCE",
],
)


@triton.heuristics(
Expand All @@ -16,7 +37,7 @@
* triton.cdiv(args["N"], args["BLOCK_SIZE_N"]),
}
)
@triton.jit
@triton.jit(repr=_gemm_a16w16_repr)
def _gemm_a16_w16_kernel(
a_ptr,
b_ptr,
Expand Down
44 changes: 44 additions & 0 deletions aiter/ops/triton/utils/_triton/kernel_repr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
def _sanitize_constexpr_value(value):
if value is None:
return "NONE"
if isinstance(value, bool):
return str(int(value))
if isinstance(value, int):
return str(value)
if isinstance(value, float):
if value.is_integer():
return str(int(value))
return str(value)

# for lists, tuples, sets - recursively join each
if isinstance(value, (list, tuple, set)):
items = sorted(value, key=str) if isinstance(value, set) else value
sanitized_items = [_sanitize_constexpr_value(item) for item in items]
joined = "_".join(sanitized_items)
return joined if joined else "NONE"

if isinstance(value, str):
cleaned_value = "".join(ch if ch.isalnum() else "_" for ch in value).strip("_")
return cleaned_value.upper() if cleaned_value else "NONE"

cleaned_value = "".join(ch if ch.isalnum() else "_" for ch in str(value)).strip("_")
return cleaned_value.upper() if cleaned_value else "NONE"


def make_kernel_repr(base_name, config_keys):
def _repr(specialization):
constants = specialization.constants
name_parts = []

for key in config_keys:
value = constants.get(key, None)
symbol = _sanitize_constexpr_value(value)
name_parts.append(f"{key}_{symbol}")

if not name_parts:
return base_name

suffix = "_".join(name_parts)
return f"{base_name}_{suffix}"

return _repr