# Kernel-name helper added as aiter/ops/triton/utils/_triton/kernel_repr.py.
#
# Consumer wiring in aiter/ops/triton/_triton_kernels/gemm_a16w16.py:
#
#     from ..utils._triton.kernel_repr import make_kernel_repr
#
#     _gemm_a16w16_repr = make_kernel_repr(
#         "_gemm_a16_w16_kernel",
#         [
#             "BLOCK_SIZE_M",
#             "BLOCK_SIZE_N",
#             "BLOCK_SIZE_K",
#             "GROUP_SIZE_M",
#             "NUM_KSPLIT",
#             "SPLITK_BLOCK_SIZE",
#             "EVEN_K",
#             "GRID_MN",
#             "cache_modifier",
#             "activation",
#             "use_activation",
#             "ADD_BIAS",
#             "SKIP_REDUCE",
#         ],
#     )
#
#     @triton.heuristics(...)
#     @triton.jit(repr=_gemm_a16w16_repr)   # previously a bare @triton.jit
#     def _gemm_a16_w16_kernel(a_ptr, b_ptr, ...):
#         ...


def _sanitize_token(text):
    """Return *text* reduced to an upper-case ``[A-Z0-9_]``-style token.

    Every character for which ``str.isalnum()`` is false becomes ``_``;
    leading and trailing underscores are stripped. An empty result is
    reported as ``"NONE"`` so the caller never emits an empty name part.
    """
    cleaned = "".join(ch if ch.isalnum() else "_" for ch in text).strip("_")
    return cleaned.upper() if cleaned else "NONE"


def _sanitize_constexpr_value(value):
    """Render a constexpr value as a fragment that is safe inside a kernel name.

    Rules, in order:
      * ``None``                      -> ``"NONE"``
      * ``bool``                      -> ``"0"`` / ``"1"``
      * ``int`` / ``float``           -> decimal text; integral floats drop the
        fraction (``2.0`` -> ``"2"``), a leading minus becomes ``NEG`` and any
        other punctuation (``.``, exponent sign) becomes ``_``
      * ``list`` / ``tuple``          -> items rendered recursively, ``_``-joined
      * ``set`` / ``frozenset``       -> same, after sorting by ``str`` so the
        result does not depend on hash order
      * anything else (incl. ``str``) -> ``str(value)`` passed through
        ``_sanitize_token``

    Returns:
        A non-empty string; empty inputs/containers yield ``"NONE"``.
    """
    if value is None:
        return "NONE"

    # bool is a subclass of int, so it has to be tested first.
    if isinstance(value, bool):
        return str(int(value))

    if isinstance(value, (int, float)):
        if isinstance(value, float) and value.is_integer():
            value = int(value)
        text = str(value)
        # Encode the sign explicitly: simply stripping "-" would make -1 and 1
        # produce the same name fragment.
        if text.startswith("-"):
            return "NEG" + _sanitize_token(text[1:])
        return _sanitize_token(text)

    if isinstance(value, (list, tuple, set, frozenset)):
        unordered = isinstance(value, (set, frozenset))
        items = sorted(value, key=str) if unordered else value
        joined = "_".join(_sanitize_constexpr_value(item) for item in items)
        return joined or "NONE"

    return _sanitize_token(str(value))


def make_kernel_repr(base_name, config_keys):
    """Build a ``repr`` callback (as passed to ``triton.jit(repr=...)``).

    Args:
        base_name: Prefix of every generated name.
        config_keys: Iterable of constexpr names to encode, in output order.
            It is copied once here, so one-shot iterables work and later
            mutation by the caller does not change generated names.

    Returns:
        A function taking a ``specialization`` object and returning
        ``"<base_name>_<KEY>_<value>_<KEY>_<value>..."``. Values are read with
        ``specialization.constants.get(key)``; a missing key is rendered as
        ``NONE``. With no keys the result is just ``base_name``.
    """
    keys = tuple(config_keys)

    def _repr(specialization):
        constants = specialization.constants
        parts = [
            f"{key}_{_sanitize_constexpr_value(constants.get(key))}"
            for key in keys
        ]
        if not parts:
            return base_name
        return f"{base_name}_{'_'.join(parts)}"

    return _repr