From cffdc67ff6d78949cf67859ab1a69eac180090c5 Mon Sep 17 00:00:00 2001
From: XiaobingSuper
Date: Thu, 16 Apr 2026 21:33:00 -0500
Subject: [PATCH 01/11] add small m hgemm

---
 aiter/aot/flydsl/gemm.py                   |  103 +-
 aiter/ops/flydsl/gemm_kernels.py           |  256 +++-
 aiter/ops/flydsl/kernels/hgemm_dispatch.py |   76 ++
 aiter/ops/flydsl/kernels/small_m_hgemm.py  | 1387 ++++++++++++++++++++
 aiter/ops/flydsl/kernels/splitk_hgemm.py   |   21 +-
 aiter/tuned_gemm.py                        |   22 +-
 gradlib/gradlib/GemmTuner.py               |   41 +-
 7 files changed, 1805 insertions(+), 101 deletions(-)
 create mode 100644 aiter/ops/flydsl/kernels/hgemm_dispatch.py
 create mode 100644 aiter/ops/flydsl/kernels/small_m_hgemm.py

diff --git a/aiter/aot/flydsl/gemm.py b/aiter/aot/flydsl/gemm.py
index 2dafa8177b..bb0c0b93ff 100644
--- a/aiter/aot/flydsl/gemm.py
+++ b/aiter/aot/flydsl/gemm.py
@@ -39,8 +39,9 @@
 from aiter.aot.flydsl.common import collect_aot_jobs, compile_only_env, job_identity
 from aiter.jit.core import AITER_CONFIGS
 from aiter.jit.utils.chip_info import get_gfx
+from aiter.ops.flydsl.gemm_kernels import get_flydsl_splitk_hgemm_kernel_params
+from aiter.ops.flydsl.kernels.hgemm_dispatch import compile_flydsl_hgemm_kernel
 from aiter.ops.flydsl.kernels.preshuffle_gemm import compile_preshuffle_gemm_a8
-from aiter.ops.flydsl.kernels.splitk_hgemm import compile_hgemm_kernel
 
 # Keep the default AOT coverage aligned with runtime config resolution.
 DEFAULT_CSVS = [
@@ -61,19 +62,6 @@
     r"(?P<tile_m>\d+)x(?P<tile_n>\d+)x(?P<tile_k>\d+)x(?P<pack_n>\d+)_"
     r"(?P<target_gfx>[A-Za-z0-9_]+)$"
 )
-_HGEMM_RE = re.compile(
-    r"^flydsl_gemm(?P<stage>\d+)_"
-    r"a(?P<a_dtype>[a-z0-9]+)_w(?P<w_dtype>[a-z0-9]+)_(?P<out_dtype>[a-z0-9]+)_"
-    r"t(?P<tile_m>\d+)x(?P<tile_n>\d+)x(?P<tile_k>\d+)_"
-    r"split_k(?P<split_k>\d+)_"
-    r"block_m_warp(?P<block_m_warps>\d+)_"
-    r"block_n_warp(?P<block_n_warps>\d+)_"
-    r"async_copy(?P<async_copy>True|False)_"
-    r"b_to_lds(?P<b_to_lds>True|False)_"
-    r"b_preshuffle(?P<b_preshuffle>True|False)_"
-    r"c_to_lds(?P<c_to_lds>True|False)_"
-    r"(?P<target_gfx>gfx[0-9a-z]+)$"
-)
 _SHORT_DTYPE = {
     "F8": "fp8",
     "I8": "int8",
 }
@@ -82,10 +70,15 @@
-def _parse_bool(value: str) -> bool:
-    if value == "True":
+def _parse_bool(value: Optional[str]) -> bool:
+    if value is None:
+        return False
+    normalized = value.strip().lower()
+    if normalized == "":
+        return False
+    if normalized in {"1", "true", "yes"}:
         return True
-    if value == "False":
+    if normalized in {"0", "false", "no"}:
         return False
     raise ValueError(f"Expected True/False, got {value!r}")
@@ -120,37 +113,6 @@ def _parse_preshuffle_kernel_name(name: str) -> Optional[Dict]:
     }
-def _parse_hgemm_kernel_name(name: str) -> Optional[Dict]:
-    m = _HGEMM_RE.fullmatch(name)
-    if m is None:
-        return None
-
-    a_dtype = m.group("a_dtype")
-    w_dtype = m.group("w_dtype")
-    if a_dtype != w_dtype:
-        raise ValueError(
-            f"Unsupported mixed HGEMM input dtypes in {name!r}: {a_dtype} vs {w_dtype}"
-        )
-
-    return {
-        "kind": "hgemm",
-        "stage": int(m.group("stage")),
-        "dtype": a_dtype,
-        "out_dtype": m.group("out_dtype"),
-        "tile_m": int(m.group("tile_m")),
-        "tile_n": int(m.group("tile_n")),
-        "tile_k": int(m.group("tile_k")),
-        "split_k": int(m.group("split_k")),
-        "block_m_warps": int(m.group("block_m_warps")),
-        "block_n_warps": int(m.group("block_n_warps")),
-        "async_copy": _parse_bool(m.group("async_copy")),
-        "b_to_lds": _parse_bool(m.group("b_to_lds")),
-        "b_preshuffle": _parse_bool(m.group("b_preshuffle")),
-        "c_to_lds": _parse_bool(m.group("c_to_lds")),
-        "target_gfx": m.group("target_gfx"),
-    }
-
-
 def parse_csv(csv_path: str):
     """Parse a GEMM tuned CSV and return a list of unique FlyDSL compile jobs."""
     jobs = []
         for row in reader:
             kernel_name = row.get("kernelName", "").strip()
row.get("kernelName", "").strip() libtype = row.get("libtype", "").strip() - if libtype != "flydsl" or not kernel_name.startswith("flydsl_"): + if libtype != "flydsl": continue m = int(row["M"]) @@ -171,7 +133,10 @@ def parse_csv(csv_path: str): if kernel_name.startswith("flydsl_bpreshuflle_"): params = _parse_preshuffle_kernel_name(kernel_name) elif kernel_name.startswith("flydsl_gemm"): - params = _parse_hgemm_kernel_name(kernel_name) + params = get_flydsl_splitk_hgemm_kernel_params(kernel_name) + if params is not None: + params = dict(params) + params["kind"] = "hgemm" else: params = None @@ -186,6 +151,7 @@ def parse_csv(csv_path: str): "m": m, "n": n, "k": k, + "has_bias": _parse_bool(row.get("bias")), **params, } key = job_identity(job) @@ -235,11 +201,17 @@ def _compile_hgemm_to_cache( split_k: int, block_m_warps: int, block_n_warps: int, + n_tile_repeat: int = 1, + persistent_n_tiles: int = 1, + waves_per_eu: int = 0, + b_to_lds_unroll: int = 0, async_copy: bool, b_to_lds: bool, b_preshuffle: bool, c_to_lds: bool, target_gfx: str, + kernel_family: str = "hgemm", + has_bias: bool = False, **kwargs, ): del kwargs, out_dtype @@ -259,6 +231,7 @@ def _compile_hgemm_to_cache( out = torch.empty((m, n), device=dev, dtype=torch_dtype) a = torch.empty((m, k), device=dev, dtype=torch_dtype) b = torch.empty((n, k), device=dev, dtype=torch_dtype) + bias = torch.empty((n,), device=dev, dtype=torch_dtype) counter = torch.zeros( (128 * 3,), device=dev, @@ -266,20 +239,30 @@ def _compile_hgemm_to_cache( ) stream = torch.cuda.current_stream(device=dev) - exe = compile_hgemm_kernel( + exe = compile_flydsl_hgemm_kernel( dtype, n, k, - TILE_M=tile_m, - TILE_N=tile_n, - TILE_K=tile_k, - SPLIT_K=split_k, - BLOCK_M_WARPS=block_m_warps, - BLOCK_N_WARPS=block_n_warps, - B_PRE_SHUFFLE=b_preshuffle, - B_TO_LDS=b_to_lds, + kernel_family=kernel_family, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + split_k=split_k, + block_m_warps=block_m_warps, + block_n_warps=block_n_warps, + n_tile_repeat=n_tile_repeat, + persistent_n_tiles=persistent_n_tiles, + waves_per_eu=waves_per_eu, + b_to_lds_unroll=b_to_lds_unroll, + async_copy=async_copy, + b_to_lds=b_to_lds, + b_preshuffle=b_preshuffle, + c_to_lds=c_to_lds, + has_bias=has_bias, + ) + _compile_executable_to_cache( + exe, out, a, b, bias if has_bias else None, m, counter, 0, stream ) - _compile_executable_to_cache(exe, out, a, b, m, counter, 0, stream) def _compile_preshuffle_to_cache( diff --git a/aiter/ops/flydsl/gemm_kernels.py b/aiter/ops/flydsl/gemm_kernels.py index 4aed782c04..887ea082a8 100644 --- a/aiter/ops/flydsl/gemm_kernels.py +++ b/aiter/ops/flydsl/gemm_kernels.py @@ -5,6 +5,7 @@ from __future__ import annotations +import re from itertools import product from typing import Dict, Optional @@ -18,7 +19,8 @@ from aiter.jit.utils.chip_info import get_gfx from ..shuffle import shuffle_weight -from .kernels.splitk_hgemm import compile_hgemm_kernel +from .kernels.hgemm_dispatch import compile_flydsl_hgemm_kernel +from .kernels.small_m_hgemm import SMALL_M_KERNEL_MAX, iter_small_m_registry_configs from .kernels.tensor_shim import _run_compiled from .utils import get_shared_memory_per_block, is_flydsl_available @@ -31,6 +33,28 @@ FIXED_STAGE = 2 FIXED_C_TO_LDS = False KERNEL_ASYNC_COPY = get_rocm_arch() != "gfx942" +KERNEL_FAMILY_HGEMM = "hgemm" +KERNEL_FAMILY_SMALL_M = "small_m" +_HGEMM_KERNEL_RE = re.compile( + r"^flydsl_gemm(?P\d+)_" + r"a(?P[a-z0-9]+)_w(?P[a-z0-9]+)_(?P[a-z0-9]+)_" + r"t(?P\d+)x(?P\d+)x(?P\d+)_" + r"split_k(?P\d+)_" + 
r"block_m_warp(?P\d+)_" + r"block_n_warp(?P\d+)_" + r"async_copy(?PTrue|False)_" + r"b_to_lds(?PTrue|False)_" + r"b_preshuffle(?PTrue|False)_" + r"c_to_lds(?PTrue|False)" + r"(?P" + r"(?:_small_m)" + r"(?:_nr(?P\d+))?" + r"(?:_pn(?P\d+))?" + r"(?:_wpe(?P\d+))?" + r"(?:_ur(?P\d+))?" + r")?" + r"_(?Pgfx[0-9a-z]+)$" +) SplitKStreamKey = tuple[int, int] SPLIT_K_GLOBAL_SEMAPHORE: dict[SplitKStreamKey, torch.Tensor] = {} @@ -91,6 +115,11 @@ def flydsl_kernel_name( b_to_lds: bool, b_preshuffle: bool, c_to_lds: bool, + kernel_family: str = KERNEL_FAMILY_HGEMM, + n_tile_repeat: int = 1, + persistent_n_tiles: int = 1, + waves_per_eu: int = 0, + b_to_lds_unroll: int = 0, ) -> str: stage, async_copy, c_to_lds = _normalize_supported_kernel_metadata( stage=stage, @@ -109,6 +138,21 @@ def flydsl_kernel_name( f"_async_copy{async_copy}_b_to_lds{b_to_lds}_b_preshuffle{b_preshuffle}" f"_c_to_lds{c_to_lds}" ) + if kernel_family == KERNEL_FAMILY_SMALL_M: + name += "_small_m" + if n_tile_repeat > 1: + name += f"_nr{n_tile_repeat}" + if persistent_n_tiles > 1: + name += f"_pn{persistent_n_tiles}" + if waves_per_eu > 0: + name += f"_wpe{waves_per_eu}" + if b_to_lds_unroll > 0: + name += f"_ur{b_to_lds_unroll}" + elif kernel_family != KERNEL_FAMILY_HGEMM: + raise ValueError( + f"Unsupported kernel_family={kernel_family!r}; expected " + f"{KERNEL_FAMILY_HGEMM!r} or {KERNEL_FAMILY_SMALL_M!r}" + ) name += f"_{get_gfx()}" return name @@ -174,6 +218,7 @@ def _validate_hgemm_inputs( a: torch.Tensor, b: torch.Tensor, out: Optional[torch.Tensor], + bias: Optional[torch.Tensor], ) -> tuple[int, int, int]: if a.dim() != 2 or b.dim() != 2: raise ValueError( @@ -209,6 +254,18 @@ def _validate_hgemm_inputs( if not out.is_contiguous(): raise ValueError("`out` must be contiguous") + if bias is not None: + if bias.dim() != 1: + raise ValueError(f"`bias` must be 1D, got bias.dim={bias.dim()}") + if bias.shape != (n,): + raise ValueError(f"`bias` must have shape {(n,)}, got {tuple(bias.shape)}") + if bias.dtype != a.dtype: + raise ValueError( + f"`bias` dtype must match input dtype, got {bias.dtype=} {a.dtype=}" + ) + if bias.device != a.device: + raise ValueError(f"`bias` must be on {a.device}, got {bias.device}") + return m, n, k @@ -355,6 +412,7 @@ def _normalize_registry_config( b_preshuffle: bool, ) -> Optional[Dict]: config = { + "kernel_family": KERNEL_FAMILY_HGEMM, "stage": FIXED_STAGE, "tile_m": int(tile_m), "tile_n": int(tile_n), @@ -394,15 +452,68 @@ def _normalize_registry_config( return config +def _parse_hgemm_kernel_params(name: str) -> Optional[Dict]: + m = _HGEMM_KERNEL_RE.fullmatch(name) + if m is None: + return None + if m.group("a_dtype") != m.group("w_dtype"): + return None + + kernel_family = ( + KERNEL_FAMILY_SMALL_M + if m.group("small_m_suffix") is not None + else KERNEL_FAMILY_HGEMM + ) + config: Dict[str, object] = { + "kernel_family": kernel_family, + "stage": int(m.group("stage")), + "tile_m": int(m.group("tile_m")), + "tile_n": int(m.group("tile_n")), + "tile_k": int(m.group("tile_k")), + "split_k": int(m.group("split_k")), + "block_m_warps": int(m.group("block_m_warps")), + "block_n_warps": int(m.group("block_n_warps")), + "async_copy": m.group("async_copy") == "True", + "b_to_lds": m.group("b_to_lds") == "True", + "b_preshuffle": m.group("b_preshuffle") == "True", + "c_to_lds": m.group("c_to_lds") == "True", + "dtype": m.group("a_dtype"), + "out_dtype": m.group("out_dtype"), + "target_gfx": m.group("target_gfx"), + } + if kernel_family == KERNEL_FAMILY_SMALL_M: + config["n_tile_repeat"] = 
int(m.group("n_tile_repeat") or 1) + config["persistent_n_tiles"] = int(m.group("persistent_n_tiles") or 1) + config["waves_per_eu"] = int(m.group("waves_per_eu") or 0) + config["b_to_lds_unroll"] = int(m.group("b_to_lds_unroll") or 0) + return config + + def get_flydsl_splitk_hgemm_kernel_params(name: str) -> Optional[Dict]: config = _SPLITK_HGEMM_KERNELS.get(name) + if config is not None: + return dict(config) + config = _parse_hgemm_kernel_params(name) if config is not None: return dict(config) return None -def get_flydsl_splitk_hgemm_kernels(dtype: str, out_dtype: str) -> Dict[str, Dict]: +def get_flydsl_splitk_hgemm_kernels( + dtype: str, + out_dtype: str, + *, + m: Optional[int] = None, + n: Optional[int] = None, + k: Optional[int] = None, +) -> Dict[str, Dict]: kernels = {} + if any(dim is None for dim in (m, n, k)) and any( + dim is not None for dim in (m, n, k) + ): + raise ValueError( + "m, n, k must be provided together when requesting shape-aware kernels" + ) tile_ns = [64, 128, 256] tile_ks = [64, 128] tile_ms = [16, 32, 48, 64, 96, 128] @@ -431,6 +542,9 @@ def get_flydsl_splitk_hgemm_kernels(dtype: str, out_dtype: str) -> Dict[str, Dic ) if config is None: continue + config["dtype"] = dtype + config["out_dtype"] = out_dtype + config["target_gfx"] = get_gfx() name = flydsl_kernel_name( config["stage"], dtype, @@ -447,6 +561,38 @@ def get_flydsl_splitk_hgemm_kernels(dtype: str, out_dtype: str) -> Dict[str, Dic config["c_to_lds"], ) kernels[name] = config + if m is not None and n is not None and k is not None: + for config in ( + iter_small_m_registry_configs( + dtype, + out_dtype, + m=m, + n=n, + k=k, + ) + or () + ): + name = flydsl_kernel_name( + config["stage"], + dtype, + out_dtype, + config["tile_m"], + config["tile_n"], + config["tile_k"], + config["split_k"], + config["block_m_warps"], + config["block_n_warps"], + config["async_copy"], + config["b_to_lds"], + config["b_preshuffle"], + config["c_to_lds"], + kernel_family=KERNEL_FAMILY_SMALL_M, + n_tile_repeat=config["n_tile_repeat"], + persistent_n_tiles=config["persistent_n_tiles"], + waves_per_eu=config["waves_per_eu"], + b_to_lds_unroll=config["b_to_lds_unroll"], + ) + kernels[name] = config return kernels @@ -514,12 +660,18 @@ def _compile_flydsl_hgemm( tile_m: int = 128, tile_n: int = 128, pack_n: int = 1, + n_tile_repeat: int = 1, + persistent_n_tiles: int = 1, + waves_per_eu: int = 0, + b_to_lds_unroll: int = 0, stages: int = FIXED_STAGE, async_copy: bool = False, b_to_lds: bool = False, b_preshuffle: bool = True, split_k: int = 1, c_to_lds: bool = False, + kernel_family: str = KERNEL_FAMILY_HGEMM, + has_bias: bool = False, ): if dtype not in {"f16", "bf16"}: raise ValueError(f"`dtype` must be 'f16' or 'bf16', got {dtype!r}") @@ -530,35 +682,66 @@ def _compile_flydsl_hgemm( if c_to_lds: raise ValueError("Current kernel does not support `c_to_lds=True`") - _validate_hgemm_tiling( - m, + if kernel_family == KERNEL_FAMILY_HGEMM: + _validate_hgemm_tiling( + m, + n, + k, + dtype=dtype, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + pack_n=pack_n, + split_k=split_k, + stages=stages, + block_m_warps=block_m_warps, + block_n_warps=block_n_warps, + b_to_lds=b_to_lds, + ) + elif kernel_family == KERNEL_FAMILY_SMALL_M: + if dtype != "bf16": + raise ValueError(f"small-M kernel only supports `bf16`, got {dtype!r}") + if stages != FIXED_STAGE: + raise ValueError( + f"small-M kernel only supports stage={FIXED_STAGE}; got stage={stages}" + ) + if b_preshuffle: + raise ValueError("small-M kernel only supports 
`b_preshuffle=False`") + if tile_m != 16: + raise ValueError(f"small-M kernel fixes tile_m=16; got tile_m={tile_m}") + if block_m_warps != 1: + raise ValueError( + "small-M kernel fixes block_m_warps=1; " + f"got block_m_warps={block_m_warps}" + ) + else: + raise ValueError( + f"Unsupported kernel_family={kernel_family!r}; expected " + f"{KERNEL_FAMILY_HGEMM!r} or {KERNEL_FAMILY_SMALL_M!r}" + ) + + kernel = compile_flydsl_hgemm_kernel( + dtype, n, k, - dtype=dtype, + kernel_family=kernel_family, tile_m=tile_m, tile_n=tile_n, tile_k=tile_k, pack_n=pack_n, split_k=split_k, - stages=stages, block_m_warps=block_m_warps, block_n_warps=block_n_warps, + n_tile_repeat=n_tile_repeat, + persistent_n_tiles=persistent_n_tiles, + waves_per_eu=waves_per_eu, + b_to_lds_unroll=b_to_lds_unroll, + stages=stages, + async_copy=async_copy, b_to_lds=b_to_lds, - ) - - del async_copy - kernel = compile_hgemm_kernel( - dtype, - n, - k, - TILE_M=tile_m, - TILE_N=tile_n, - TILE_K=tile_k, - SPLIT_K=split_k, - BLOCK_M_WARPS=block_m_warps, - BLOCK_N_WARPS=block_n_warps, - B_PRE_SHUFFLE=b_preshuffle, - B_TO_LDS=b_to_lds, + b_preshuffle=b_preshuffle, + c_to_lds=c_to_lds, + has_bias=has_bias, ) def launcher( @@ -566,8 +749,19 @@ def launcher( a: torch.Tensor, b: torch.Tensor, signal_state: int, + bias: Optional[torch.Tensor] = None, stream: Optional[torch.cuda.Stream] = None, ): + if has_bias and bias is None: + raise ValueError( + "This launcher was compiled with bias support and requires `bias`." + ) + if not has_bias and bias is not None: + raise ValueError( + "This launcher was compiled without bias support; " + "recompile with `has_bias=True`." + ) + launch_bias = b if bias is None else bias runtime_m = int(a.shape[0]) _check_split_k_counter_capacity(runtime_m, n, tile_m, tile_n, split_k) launch_stream = _normalize_launch_stream(a.device, stream) @@ -577,6 +771,7 @@ def launcher( out, a, b, + launch_bias, runtime_m, semaphore, signal_state, @@ -591,6 +786,7 @@ def flydsl_hgemm( b: torch.Tensor, out: Optional[torch.Tensor] = None, *, + bias: Optional[torch.Tensor] = None, tile_m: int = 128, tile_n: int = 128, tile_k: int = 64, @@ -598,23 +794,30 @@ def flydsl_hgemm( split_k: int = 1, block_m_warps: int = 1, block_n_warps: int = 4, + n_tile_repeat: int = 1, + persistent_n_tiles: int = 1, + waves_per_eu: int = 0, + b_to_lds_unroll: int = 0, stages: int = FIXED_STAGE, async_copy: bool = False, b_to_lds: bool = False, b_preshuffle: bool = True, auto_shuffle_b: bool = False, c_to_lds: bool = False, + kernel_family: Optional[str] = None, stream: Optional[torch.cuda.Stream] = None, ) -> torch.Tensor: """Run FlyDSL HGEMM.""" - m, n, k = _validate_hgemm_inputs(a, b, out) + m, n, k = _validate_hgemm_inputs(a, b, out, bias) kernel_dtype = _to_kernel_dtype(a.dtype) if not a.is_contiguous(): a = a.contiguous() if not b.is_contiguous(): b = b.contiguous() + if bias is not None and not bias.is_contiguous(): + bias = bias.contiguous() if b_preshuffle and not getattr(b, "is_shuffled", False): if auto_shuffle_b: @@ -631,7 +834,6 @@ def flydsl_hgemm( launch_stream = _normalize_launch_stream(a.device, stream) signal_state = _get_split_k_signal_state(launch_stream) - launcher = _compile_flydsl_hgemm( kernel_dtype, m, @@ -643,15 +845,21 @@ def flydsl_hgemm( tile_m=tile_m, tile_n=tile_n, pack_n=pack_n, + n_tile_repeat=n_tile_repeat, + persistent_n_tiles=persistent_n_tiles, + waves_per_eu=waves_per_eu, + b_to_lds_unroll=b_to_lds_unroll, stages=stages, async_copy=async_copy, b_to_lds=b_to_lds, b_preshuffle=b_preshuffle, split_k=split_k, 
c_to_lds=c_to_lds, + kernel_family=(KERNEL_FAMILY_HGEMM if kernel_family is None else kernel_family), + has_bias=bias is not None, ) - launcher(out, a, b, signal_state=signal_state, stream=launch_stream) + launcher(out, a, b, signal_state=signal_state, bias=bias, stream=launch_stream) if split_k > 1: _advance_split_k_signal_state(launch_stream) return out diff --git a/aiter/ops/flydsl/kernels/hgemm_dispatch.py b/aiter/ops/flydsl/kernels/hgemm_dispatch.py new file mode 100644 index 0000000000..825dfa29f7 --- /dev/null +++ b/aiter/ops/flydsl/kernels/hgemm_dispatch.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +from typing import Optional + +from .small_m_hgemm import compile_small_m_hgemm_kernel +from .splitk_hgemm import compile_hgemm_kernel + +KERNEL_FAMILY_HGEMM = "hgemm" +KERNEL_FAMILY_SMALL_M = "small_m" + + +def compile_flydsl_hgemm_kernel( + dtype: str, + n: int, + k: int, + *, + kernel_family: Optional[str] = None, + tile_m: int = 128, + tile_n: int = 128, + tile_k: int = 64, + pack_n: int = 1, + split_k: int = 1, + block_m_warps: int = 1, + block_n_warps: int = 4, + n_tile_repeat: int = 1, + persistent_n_tiles: int = 1, + waves_per_eu: int = 0, + b_to_lds_unroll: int = 0, + stages: int = 2, + async_copy: bool = False, + b_to_lds: bool = False, + b_preshuffle: bool = True, + c_to_lds: bool = False, + has_bias: bool = False, +): + """Build one FlyDSL HGEMM-family kernel from a unified config surface.""" + + del pack_n, stages, async_copy, c_to_lds + + if kernel_family in (None, KERNEL_FAMILY_HGEMM): + return compile_hgemm_kernel( + dtype, + n, + k, + TILE_M=tile_m, + TILE_N=tile_n, + TILE_K=tile_k, + SPLIT_K=split_k, + BLOCK_M_WARPS=block_m_warps, + BLOCK_N_WARPS=block_n_warps, + B_PRE_SHUFFLE=b_preshuffle, + B_TO_LDS=b_to_lds, + HAS_BIAS=has_bias, + ) + + if kernel_family == KERNEL_FAMILY_SMALL_M: + return compile_small_m_hgemm_kernel( + dtype, + n, + k, + TILE_N=tile_n, + TILE_K=tile_k, + SPLIT_K=split_k, + BLOCK_N_WARPS=block_n_warps, + N_TILE_REPEAT=n_tile_repeat, + PERSISTENT_N_TILES=persistent_n_tiles, + WAVES_PER_EU_HINT=waves_per_eu, + B_TO_LDS_UNROLL=b_to_lds_unroll, + B_TO_LDS=b_to_lds, + HAS_BIAS=has_bias, + ) + + raise ValueError( + f"Unsupported kernel_family={kernel_family!r}; expected " + f"{KERNEL_FAMILY_HGEMM!r} or {KERNEL_FAMILY_SMALL_M!r}" + ) diff --git a/aiter/ops/flydsl/kernels/small_m_hgemm.py b/aiter/ops/flydsl/kernels/small_m_hgemm.py new file mode 100644 index 0000000000..ffe6c1bb29 --- /dev/null +++ b/aiter/ops/flydsl/kernels/small_m_hgemm.py @@ -0,0 +1,1387 @@ +"""Dedicated small-M bf16 HGEMM kernel path. + +This module intentionally stays separate from `hgemm.py`. The generic HGEMM +kernel and this small-M path share the same split-K contract and both still +take `m` as a runtime value, but this path is no longer just a different +parameter point of one template: + +- `TILE_M=16` and `BLOCK_M_WARPS=1` are hard-wired so the block spends its + wave budget on N/K work instead of over-parallelizing the tiny M dimension. + Concretely, the block only covers one 16-row M tile and avoids launching + extra M-side warps whose useful work would quickly disappear once `m` is + much smaller than a generic HGEMM tile. +- Warp mapping is specialized for tiny-M shapes: warps do not spread across + the M dimension like the generic kernel, and more of the wave budget is used + to cover N-side work. 
In the hot path this shows up as `warp_m_idx = 0` and + `warp_n_idx = wid * WARP_N`, so the whole block behaves like "one small M + slice, many N workers" instead of a more balanced 2D warp decomposition. +- The kernel adds small-M-specific wide-N mechanisms: + `N_TILE_REPEAT` for non-`B_TO_LDS` multi-tile accumulation and + `PERSISTENT_N_TILES` for the `B_TO_LDS` persistent-N path. The first lets one + block reuse the same loaded A fragments while accumulating several N tiles in + registers; the second lets a `B_TO_LDS` block stay on a small group of N + tiles longer so the cost of setting up the tiny-M tile is amortized over more + useful N-side work. +- The `B_TO_LDS` hot loop is tuned separately with an explicit unroll knob and + a dedicated wide-N scheduler, rather than reusing the generic `hgemm.py` + scheduling structure. `B_TO_LDS_UNROLL` controls how many K iterations are + pipelined per outer step, and the wide-N scheduler adjusts the DS/VMEM/MFMA + issue pattern so LDS reads, async B loads, and matrix instructions stay + better balanced for these skinny-M / wide-N shapes. + +In practice, the main optimization goal here is to improve decode-like GEMMs +where M is tiny while N/K stay large: reduce wasted M-side parallelism, reuse +the loaded A tile across more N work, and give wide-N shapes a more specialized +schedule than the generic HGEMM kernel. +""" + +from __future__ import annotations + +import functools +import os +from itertools import product + +import flydsl.compiler as flyc +import flydsl.expr as fx +from aiter.jit.utils.chip_info import get_gfx +from flydsl._mlir import ir +from flydsl._mlir.dialects import fly, llvm, memref, scf +from flydsl.compiler.kernel_function import CompilationContext +from flydsl.compiler.protocol import fly_values +from flydsl.expr import arith, gpu, range_constexpr, rocdl, vector +from flydsl.expr.typing import T +from flydsl.runtime.device import get_rocm_arch +from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr + +from .splitk_hgemm import ( + OnlineScheduler, + SPLIT_K_COUNTER_MAX_LEN, + WmmaHalf_m16n16k32, + swizzle_xor16, +) +from .tensor_shim import GTensor, STensor, _to_raw, get_dtype_in_kernel + +__all__ = [ + "compile_small_m_hgemm_kernel", + "iter_small_m_registry_configs", + "SMALL_M_KERNEL_MAX", + "small_m_kernel_name", +] + +SMALL_M_KERNEL_MAX = 17 +TILE_M = 16 +BLOCK_M_WARPS = 1 +STAGES = 2 +WARP_SIZE = 64 +DTYPE_BYTES = 2 +LDG_VEC_SIZE = 8 +MAX_LDS_BYTES = 163840 + +# Default to a compact search space so offline tuning remains tractable. +# Set `AITER_FLYDSL_SMALL_M_SEARCH_SPACE=exhaustive` to recover the wider +# catalog used for deeper one-off searches. 
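+# For example, a one-off exhaustive tuning pass could opt in via the
+# environment (hypothetical driver script, shown only to illustrate the knob):
+#   AITER_FLYDSL_SMALL_M_SEARCH_SPACE=exhaustive python my_tuning_script.py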
+SMALL_M_SEARCH_SPACE = ( + os.getenv("AITER_FLYDSL_SMALL_M_SEARCH_SPACE", "compact").strip().lower() +) +if SMALL_M_SEARCH_SPACE not in {"compact", "exhaustive"}: + raise ValueError( + "Unsupported AITER_FLYDSL_SMALL_M_SEARCH_SPACE=" + f"{SMALL_M_SEARCH_SPACE!r}; expected 'compact' or 'exhaustive'" + ) + +if SMALL_M_SEARCH_SPACE == "exhaustive": + SMALL_M_TILE_K_OPTIONS = [32, 64, 96, 128, 160, 192, 256] + SMALL_M_TILE_N_OPTIONS = [ + 32, + 64, + 96, + 128, + 160, + 192, + 224, + 256, + 384, + 512, + 768, + 1024, + ] + SMALL_M_SPLIT_K_OPTIONS = [1, 2, 4, 8, 16, 32] + SMALL_M_NON_B_TO_LDS_WAVES_PER_EU_OPTIONS = [0, 2, 4] + SMALL_M_B_TO_LDS_WAVES_PER_EU_OPTIONS = [0, 2, 4] + SMALL_M_B_TO_LDS_UNROLL_OPTIONS = [0, 8, 16] + SMALL_M_N_TILE_REPEAT_OPTIONS = [1, 2, 4] + SMALL_M_PERSISTENT_N_TILE_OPTIONS = [2, 4, 8] + SMALL_M_BASE_BLOCK_N_WARPS = (1, 2, 3, 4) + SMALL_M_REPEAT_BLOCK_N_WARPS = (1, 2) + SMALL_M_B_TO_LDS_BLOCK_N_WARPS = (1, 2, 3, 4) + SMALL_M_PERSISTENT_BLOCK_N_WARPS = (2, 3, 4) +else: + SMALL_M_TILE_K_OPTIONS = [64, 128, 256] + SMALL_M_TILE_N_OPTIONS = [64, 128, 192, 256, 512, 1024] + SMALL_M_SPLIT_K_OPTIONS = [1, 2, 4, 8, 16] + SMALL_M_NON_B_TO_LDS_WAVES_PER_EU_OPTIONS = [0] + SMALL_M_B_TO_LDS_WAVES_PER_EU_OPTIONS = [0, 2, 4] + SMALL_M_B_TO_LDS_UNROLL_OPTIONS = [8, 16] + SMALL_M_N_TILE_REPEAT_OPTIONS = [1, 2] + SMALL_M_PERSISTENT_N_TILE_OPTIONS = [2, 4] + SMALL_M_BASE_BLOCK_N_WARPS = (1, 2, 4) + SMALL_M_REPEAT_BLOCK_N_WARPS = (1, 2) + SMALL_M_B_TO_LDS_BLOCK_N_WARPS = (1, 2, 4) + SMALL_M_PERSISTENT_BLOCK_N_WARPS = (2, 4) + + +def _ceil_div(x: int, y: int) -> int: + return (x + y - 1) // y + + +def _align_up(x: int, y: int) -> int: + return ((x + y - 1) // y) * y + + +def small_m_kernel_name( + dtype: str, + *, + tile_n: int, + tile_k: int, + split_k: int, + block_n_warps: int, + n_tile_repeat: int, + persistent_n_tiles: int, + waves_per_eu: int, + b_to_lds_unroll: int, + b_to_lds: bool, + has_bias: bool, +) -> str: + name = ( + f"smallm_hgemm_{dtype}_{TILE_M}x{tile_n}x{tile_k}_S{STAGES}TN_AS" + f"_BNW{block_n_warps}" + ) + if n_tile_repeat > 1: + name += f"_NR{n_tile_repeat}" + if persistent_n_tiles > 1: + name += f"_PN{persistent_n_tiles}" + if split_k > 1: + name += f"_SPK{split_k}" + if b_to_lds: + name += "_BS" + if waves_per_eu > 0: + name += f"_WPE{waves_per_eu}" + if b_to_lds_unroll > 0: + name += f"_UR{b_to_lds_unroll}" + if has_bias: + name += "_BIAS" + return name + + +def _validate_small_m_registry_config( + m: int, + n: int, + k: int, + *, + tile_n: int, + tile_k: int, + split_k: int, + block_n_warps: int, + n_tile_repeat: int, + persistent_n_tiles: int, + waves_per_eu: int, + b_to_lds_unroll: int, + b_to_lds: bool, +) -> None: + del waves_per_eu + + if not (1 <= m < SMALL_M_KERNEL_MAX): + raise ValueError + if tile_n < 1 or tile_k < 32 or tile_k % 32 != 0: + raise ValueError + if block_n_warps < 1 or split_k < 1: + raise ValueError + if n_tile_repeat < 1 or persistent_n_tiles < 1: + raise ValueError + if b_to_lds_unroll < 0: + raise ValueError + if tile_n % (block_n_warps * 16) != 0: + raise ValueError + if n_tile_repeat > 1: + if b_to_lds: + raise ValueError + classic_repeat = block_n_warps == 1 and tile_n == 64 + wave_repeat = n_tile_repeat == 2 and block_n_warps == 2 and tile_n == 192 + if not (classic_repeat or wave_repeat): + raise ValueError + if persistent_n_tiles > 1: + if not b_to_lds or n_tile_repeat != 1 or tile_n < 128 or block_n_warps < 2: + raise ValueError + if n < tile_n or n % tile_n != 0: + raise ValueError + if persistent_n_tiles > n // tile_n: + raise 
ValueError + if k % split_k != 0: + raise ValueError + ks = k // split_k + if ks < tile_k or ks % tile_k != 0: + raise ValueError + + a_lds_bytes = max(2 * TILE_M * tile_k * DTYPE_BYTES, TILE_M * tile_n * DTYPE_BYTES) + lds_bytes = ( + a_lds_bytes + if not b_to_lds + else _align_up(a_lds_bytes, 16) + 2 * tile_n * tile_k * DTYPE_BYTES + ) + if lds_bytes > MAX_LDS_BYTES: + raise ValueError + + +def _small_m_registry_variants(): + variants = [] + seen_variants = set() + + def add_variant( + *, + block_n_warps: int, + b_to_lds: bool, + n_tile_repeat: int = 1, + persistent_n_tiles: int = 1, + waves_per_eu: int = 0, + b_to_lds_unroll: int = 0, + ) -> None: + variant = { + "block_m_warps": BLOCK_M_WARPS, + "block_n_warps": block_n_warps, + "b_to_lds": b_to_lds, + "n_tile_repeat": n_tile_repeat, + "persistent_n_tiles": persistent_n_tiles, + "waves_per_eu": waves_per_eu, + "b_to_lds_unroll": b_to_lds_unroll, + } + variant_key = tuple(sorted(variant.items())) + if variant_key in seen_variants: + return + seen_variants.add(variant_key) + variants.append(variant) + + for block_n_warps in SMALL_M_BASE_BLOCK_N_WARPS: + for waves_per_eu in SMALL_M_NON_B_TO_LDS_WAVES_PER_EU_OPTIONS: + add_variant( + block_n_warps=block_n_warps, + b_to_lds=False, + waves_per_eu=waves_per_eu, + ) + + for n_tile_repeat in SMALL_M_N_TILE_REPEAT_OPTIONS[1:]: + for block_n_warps in SMALL_M_REPEAT_BLOCK_N_WARPS: + for waves_per_eu in SMALL_M_NON_B_TO_LDS_WAVES_PER_EU_OPTIONS: + add_variant( + block_n_warps=block_n_warps, + b_to_lds=False, + n_tile_repeat=n_tile_repeat, + waves_per_eu=waves_per_eu, + ) + + for block_n_warps in SMALL_M_B_TO_LDS_BLOCK_N_WARPS: + for waves_per_eu in SMALL_M_B_TO_LDS_WAVES_PER_EU_OPTIONS: + for b_to_lds_unroll in SMALL_M_B_TO_LDS_UNROLL_OPTIONS: + add_variant( + block_n_warps=block_n_warps, + b_to_lds=True, + waves_per_eu=waves_per_eu, + b_to_lds_unroll=b_to_lds_unroll, + ) + + for persistent_n_tiles in SMALL_M_PERSISTENT_N_TILE_OPTIONS: + for block_n_warps in SMALL_M_PERSISTENT_BLOCK_N_WARPS: + for waves_per_eu in SMALL_M_B_TO_LDS_WAVES_PER_EU_OPTIONS: + for b_to_lds_unroll in SMALL_M_B_TO_LDS_UNROLL_OPTIONS: + add_variant( + block_n_warps=block_n_warps, + b_to_lds=True, + persistent_n_tiles=persistent_n_tiles, + waves_per_eu=waves_per_eu, + b_to_lds_unroll=b_to_lds_unroll, + ) + + return tuple(variants) + + +def iter_small_m_registry_configs( + dtype: str, + out_dtype: str, + *, + m: int, + n: int, + k: int, +): + if dtype != "bf16" or out_dtype != "bf16": + return + + gpu_arch = get_rocm_arch() + if gpu_arch == "gfx942" or not (1 <= m < SMALL_M_KERNEL_MAX): + return + + seen_configs = set() + for tile_n, tile_k, split_k in product( + SMALL_M_TILE_N_OPTIONS, + SMALL_M_TILE_K_OPTIONS, + SMALL_M_SPLIT_K_OPTIONS, + ): + for variant in _small_m_registry_variants(): + config = { + "kernel_family": "small_m", + "stage": STAGES, + "tile_m": TILE_M, + "tile_n": tile_n, + "tile_k": tile_k, + "split_k": split_k, + "block_m_warps": BLOCK_M_WARPS, + "block_n_warps": variant["block_n_warps"], + "n_tile_repeat": variant["n_tile_repeat"], + "persistent_n_tiles": variant["persistent_n_tiles"], + "waves_per_eu": variant["waves_per_eu"], + "b_to_lds_unroll": variant["b_to_lds_unroll"], + "async_copy": True, + "b_to_lds": variant["b_to_lds"], + "b_preshuffle": False, + "c_to_lds": False, + "dtype": dtype, + "out_dtype": out_dtype, + "target_gfx": get_gfx(), + } + try: + _validate_small_m_registry_config( + m, + n, + k, + tile_n=config["tile_n"], + tile_k=config["tile_k"], + split_k=config["split_k"], + 
block_n_warps=config["block_n_warps"], + n_tile_repeat=config["n_tile_repeat"], + persistent_n_tiles=config["persistent_n_tiles"], + waves_per_eu=config["waves_per_eu"], + b_to_lds_unroll=config["b_to_lds_unroll"], + b_to_lds=config["b_to_lds"], + ) + except ValueError: + continue + config_key = tuple(sorted(config.items())) + if config_key in seen_configs: + continue + seen_configs.add(config_key) + yield config + + +@functools.lru_cache(maxsize=512) +def compile_small_m_hgemm_kernel( + dtype: str, + n: int, + k: int, + *, + TILE_N: int = 128, + TILE_K: int = 64, + SPLIT_K: int = 1, + BLOCK_N_WARPS: int = 2, + N_TILE_REPEAT: int = 1, + PERSISTENT_N_TILES: int = 1, + WAVES_PER_EU_HINT: int = 0, + B_TO_LDS_UNROLL: int = 0, + B_TO_LDS: bool = False, + HAS_BIAS: bool = False, +): + if dtype != "bf16": + raise ValueError(f"`small_m_hgemm.py` only supports bf16, got {dtype!r}") + if SPLIT_K < 1: + raise ValueError(f"SPLIT_K must be >= 1, got {SPLIT_K}") + + GPU_ARCH = get_rocm_arch() + if GPU_ARCH == "gfx942": + raise ValueError("small-M kernel currently targets the async-copy bf16 path") + + WMMA_IMPL = WmmaHalf_m16n16k32(dtype) + DMA_BYTES = 16 + MFMA_PER_WARP_K = 1 + BLOCK_K = TILE_K + IS_SPLIT_K = SPLIT_K > 1 + assert (k % SPLIT_K == 0) and (k // SPLIT_K >= 1) + ks = k // SPLIT_K + assert (ks % BLOCK_K == 0) and (ks // BLOCK_K >= 1) + assert BLOCK_K >= 32 + + WMMA_M = WMMA_IMPL.WMMA_M + WMMA_N = WMMA_IMPL.WMMA_N + WMMA_K = WMMA_IMPL.WMMA_K + WMMA_A_FRAG_VALUES = WMMA_IMPL.WMMA_A_FRAG_VALUES + WMMA_B_FRAG_VALUES = WMMA_IMPL.WMMA_B_FRAG_VALUES + WMMA_C_FRAG_VALUES = WMMA_IMPL.WMMA_C_FRAG_VALUES + WARP_ATOM_M = WMMA_M + WARP_ATOM_N = WMMA_N + WARP_ATOM_K = WMMA_K * MFMA_PER_WARP_K + BLOCK_K_LOOPS = ks // BLOCK_K + WARP_K_STEPS = BLOCK_K // WARP_ATOM_K + assert (BLOCK_K % WARP_ATOM_K == 0) and (WARP_K_STEPS >= 1) + + BLOCK_THREADS = BLOCK_N_WARPS * WARP_SIZE + WARP_M_STEPS = TILE_M // BLOCK_M_WARPS // WARP_ATOM_M + WARP_N_STEPS = TILE_N // BLOCK_N_WARPS // WARP_ATOM_N + assert WARP_M_STEPS == 1 + assert (WARP_N_STEPS >= 1) and (TILE_N % (BLOCK_N_WARPS * WARP_ATOM_N) == 0) + + WARP_M = WARP_M_STEPS * WARP_ATOM_M + WARP_N = WARP_N_STEPS * WARP_ATOM_N + BLOCK_M = BLOCK_M_WARPS * WARP_M + BLOCK_N = BLOCK_N_WARPS * WARP_N + assert BLOCK_M == TILE_M + assert (n >= BLOCK_N) and (n % BLOCK_N == 0) + BLOCK_N_TILES = n // BLOCK_N + if N_TILE_REPEAT > 1: + if B_TO_LDS: + raise ValueError("wide-N repeat path only supports B_TO_LDS=False") + classic_repeat = BLOCK_N_WARPS == 1 and TILE_N == 64 + wave_repeat = N_TILE_REPEAT == 2 and BLOCK_N_WARPS == 2 and TILE_N == 192 + if not (classic_repeat or wave_repeat): + raise ValueError( + "wide-N repeat path requires either the classic " + "(BLOCK_N_WARPS=1, TILE_N=64, N_TILE_REPEAT>1) setup or the " + "wave-specialized (N_TILE_REPEAT=2, BLOCK_N_WARPS=2, TILE_N=192) setup" + ) + if PERSISTENT_N_TILES > 1: + if not B_TO_LDS: + raise ValueError("persistent-N path requires B_TO_LDS=True") + if N_TILE_REPEAT != 1: + raise ValueError("persistent-N path requires N_TILE_REPEAT=1") + if TILE_N < 128: + raise ValueError("persistent-N path currently requires TILE_N >= 128") + if BLOCK_N_WARPS < 2: + raise ValueError("persistent-N path currently requires BLOCK_N_WARPS >= 2") + if PERSISTENT_N_TILES > BLOCK_N_TILES: + raise ValueError( + "persistent-N path requires PERSISTENT_N_TILES <= total N tiles; " + f"got {PERSISTENT_N_TILES} > {BLOCK_N_TILES}" + ) + PERSISTENT_N = PERSISTENT_N_TILES > 1 + WIDE_N_B_TO_LDS = ( + B_TO_LDS and N_TILE_REPEAT == 1 and TILE_N >= 128 and 
BLOCK_N_WARPS >= 2 + ) + WAVES_PER_EU = ( + int(WAVES_PER_EU_HINT) + if WAVES_PER_EU_HINT > 0 + else (2 if WIDE_N_B_TO_LDS else 0) + ) + EFFECTIVE_B_TO_LDS_UNROLL = int(B_TO_LDS_UNROLL) if B_TO_LDS_UNROLL > 0 else 8 + + BLOCK_MK_SIZE = BLOCK_M * BLOCK_K + BLOCK_NK_SIZE = BLOCK_N * BLOCK_K + BLOCK_MN_SIZE = BLOCK_M * BLOCK_N + LDG_A_X_THREADS = BLOCK_K // LDG_VEC_SIZE + LDG_B_X_THREADS = BLOCK_K // LDG_VEC_SIZE + LDG_C_X_THREADS = BLOCK_N // LDG_VEC_SIZE + assert BLOCK_MK_SIZE % LDG_VEC_SIZE == 0 + assert BLOCK_NK_SIZE % LDG_VEC_SIZE == 0 + assert BLOCK_MN_SIZE % LDG_VEC_SIZE == 0 + LDG_A_TOTAL_VECS = BLOCK_MK_SIZE // LDG_VEC_SIZE + LDG_B_TOTAL_VECS = BLOCK_NK_SIZE // LDG_VEC_SIZE + LDG_C_TOTAL_VECS = BLOCK_MN_SIZE // LDG_VEC_SIZE + LDG_REG_A_COUNT = _ceil_div(LDG_A_TOTAL_VECS, BLOCK_THREADS) + LDG_REG_B_COUNT = _ceil_div(LDG_B_TOTAL_VECS, BLOCK_THREADS) + LDG_REG_C_COUNT = _ceil_div(LDG_C_TOTAL_VECS, BLOCK_THREADS) + assert (LDG_REG_A_COUNT >= 1) and (LDG_REG_B_COUNT >= 1) and (LDG_REG_C_COUNT >= 1) + + BLOCK_K_BYTES = BLOCK_K * DTYPE_BYTES + + allocator = SmemAllocator(None, arch=GPU_ARCH, global_sym_name="smem") + smem_a_offset = allocator._align(allocator.ptr, 16) + AS_BYTES = STAGES * BLOCK_M * BLOCK_K * DTYPE_BYTES + AS_BYTES = max(AS_BYTES, BLOCK_M * BLOCK_N * DTYPE_BYTES) + allocator.ptr = smem_a_offset + AS_BYTES + SMEM_USE = AS_BYTES + if B_TO_LDS: + smem_b_offset = allocator._align(allocator.ptr, 16) + allocator.ptr = smem_b_offset + STAGES * BLOCK_N * BLOCK_K * DTYPE_BYTES + SMEM_USE += STAGES * BLOCK_N * BLOCK_K * DTYPE_BYTES + assert SMEM_USE <= MAX_LDS_BYTES + + LDG_ASYNC_VEC_SIZE = DMA_BYTES // DTYPE_BYTES + LDG_A_X_THREADS_AS = BLOCK_K // LDG_ASYNC_VEC_SIZE + LDG_B_X_THREADS_AS = BLOCK_K // LDG_ASYNC_VEC_SIZE + assert BLOCK_MK_SIZE % LDG_ASYNC_VEC_SIZE == 0 + assert BLOCK_NK_SIZE % LDG_ASYNC_VEC_SIZE == 0 + LDG_A_TOTAL_VECS_AS = BLOCK_MK_SIZE // LDG_ASYNC_VEC_SIZE + LDG_B_TOTAL_VECS_AS = BLOCK_NK_SIZE // LDG_ASYNC_VEC_SIZE + LDG_REG_A_COUNT_AS = _ceil_div(LDG_A_TOTAL_VECS_AS, BLOCK_THREADS) + LDG_REG_B_COUNT_AS = _ceil_div(LDG_B_TOTAL_VECS_AS, BLOCK_THREADS) + + KERNEL_NAME = small_m_kernel_name( + dtype, + tile_n=TILE_N, + tile_k=TILE_K, + split_k=SPLIT_K, + block_n_warps=BLOCK_N_WARPS, + n_tile_repeat=N_TILE_REPEAT, + persistent_n_tiles=PERSISTENT_N_TILES, + waves_per_eu=WAVES_PER_EU, + b_to_lds_unroll=EFFECTIVE_B_TO_LDS_UNROLL if B_TO_LDS else 0, + b_to_lds=B_TO_LDS, + has_bias=HAS_BIAS, + ) + + @flyc.kernel + def small_m_hgemm_kernel( + C: fx.Tensor, + A: fx.Tensor, + B: fx.Tensor, + BIAS: fx.Tensor, + m: fx.Int32, + COUNTER: fx.Tensor, + signal_state: fx.Int32, + ): + dtype_ = get_dtype_in_kernel(dtype) + _ptr_type = ir.Type.parse("!llvm.ptr<1>") + _i64_type = T.i64 + c_zero_d = arith.constant(0.0, type=dtype_) + acc_init = arith.constant_vector(0.0, T.vec(WMMA_C_FRAG_VALUES, T.f32)) + zero_a_vec = vector.broadcast(T.vec(LDG_VEC_SIZE, dtype_), c_zero_d) + zero_a_async_vec = vector.broadcast(T.vec(LDG_ASYNC_VEC_SIZE, dtype_), c_zero_d) + + A_ = GTensor(A, dtype=dtype_, shape=(-1, k)) + B_ = GTensor(B, dtype=dtype_, shape=(n, k)) + C_ = GTensor(C, dtype=dtype_, shape=(-1, n)) + if HAS_BIAS: + BIAS_ = GTensor(BIAS, dtype=dtype_, shape=(n,)) + + base_ptr = allocator.get_base() + smem_a_ptr = SmemPtr( + base_ptr, + smem_a_offset, + dtype_, + shape=(STAGES * BLOCK_M * BLOCK_K,), + ) + as_ = STensor(smem_a_ptr, dtype_, shape=(STAGES, BLOCK_M, BLOCK_K)) + if B_TO_LDS: + smem_b_ptr = SmemPtr( + base_ptr, + smem_b_offset, + dtype_, + shape=(STAGES * BLOCK_N * BLOCK_K,), + ) + 
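+            # STAGES=2 double-buffers the B tile in LDS: the MFMA loop reads
+            # stage `current_stage` while async buffer loads fill the other
+            # stage (`next_stage = 1 - current_stage` in the hot loop below).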
bs_ = STensor(smem_b_ptr, dtype_, shape=(STAGES, BLOCK_N, BLOCK_K)) + smem_c_ptr = SmemPtr( + base_ptr, smem_a_offset, dtype_, shape=(BLOCK_M * BLOCK_N,) + ) + cs_ = STensor(smem_c_ptr, dtype_, shape=(BLOCK_M, BLOCK_N)) + if IS_SPLIT_K: + COUNTER_ = GTensor(COUNTER, dtype=T.i32, shape=(-1,)) + + tid = fx.Int32(fx.thread_idx.x) + wid = tid // WARP_SIZE + w_tid = tid % WARP_SIZE + block_m_idx = fx.block_idx.x + block_n_group_idx = fx.Index(fx.block_idx.y) + ks_idx = fx.Index(fx.block_idx.z) + ks_begin = arith.index_cast(T.i32, ks_idx * ks) + block_n_tiles = n // BLOCK_N + tile_group = PERSISTENT_N_TILES if PERSISTENT_N else N_TILE_REPEAT + + m_offset = fx.Index(block_m_idx * BLOCK_M) + tile_block_n_indices = [ + block_n_group_idx * fx.Index(tile_group) + fx.Index(tile_i) + for tile_i in range_constexpr(tile_group) + ] + tile_n_offsets = [ + tile_block_n_idx * fx.Index(BLOCK_N) + for tile_block_n_idx in tile_block_n_indices + ] + tile_actives = [ + arith.cmpi( + arith.CmpIPredicate.ult, + tile_block_n_idx, + fx.Index(block_n_tiles), + ) + for tile_block_n_idx in tile_block_n_indices + ] + tile_counter_indices = [ + fx.Int32(signal_state * SPLIT_K_COUNTER_MAX_LEN) + + fx.block_idx.x * fx.Int32(block_n_tiles) + + arith.index_cast(T.i32, tile_block_n_idx) + for tile_block_n_idx in tile_block_n_indices + ] + k_blocks16 = fx.Int32(BLOCK_K_BYTES // 16) + + warp_m_idx = fx.Int32(0) + warp_n_idx = wid * WARP_N + ldmatrix_a_m_idx = w_tid % WMMA_M + ldmatrix_a_k_vec_idx = w_tid // WMMA_M * WMMA_A_FRAG_VALUES * MFMA_PER_WARP_K + ldmatrix_b_n_idx = w_tid % WMMA_N + ldmatrix_b_k_vec_idx = w_tid // WMMA_N * WMMA_B_FRAG_VALUES * MFMA_PER_WARP_K + + A_FRAGS_LEN = WARP_K_STEPS * WARP_M_STEPS + B_FRAGS_LEN = WARP_K_STEPS * WARP_N_STEPS + C_FRAGS_LEN = WARP_M_STEPS * WARP_N_STEPS + B_FRAG_T = T.vec(WMMA_B_FRAG_VALUES * MFMA_PER_WARP_K, dtype_) + zero_b_frag = vector.broadcast(B_FRAG_T, c_zero_d) + c_frags = [acc_init] * (C_FRAGS_LEN * N_TILE_REPEAT) + + def zero_c_tile(tile_n_offset): + zero_vec = vector.broadcast(T.vec(LDG_VEC_SIZE, dtype_), c_zero_d) + for i in range_constexpr(LDG_REG_C_COUNT): + global_tid = BLOCK_THREADS * i + tid + m_local_idx = global_tid // LDG_C_X_THREADS + n_local_idx = global_tid % LDG_C_X_THREADS * LDG_VEC_SIZE + row_idx = m_offset + fx.Index(m_local_idx) + init_vec = zero_vec + if HAS_BIAS: + init_vec = BIAS_.vec_load( + (tile_n_offset + n_local_idx,), LDG_VEC_SIZE + ) + cond_boundary = arith.cmpi( + arith.CmpIPredicate.ult, row_idx, fx.Index(m) + ) + cond_boundary_if = scf.IfOp(cond_boundary, results_=[], has_else=False) + with ir.InsertionPoint(cond_boundary_if.then_block): + C_.vec_store( + (row_idx, tile_n_offset + n_local_idx), init_vec, LDG_VEC_SIZE + ) + scf.YieldOp([]) + + def init_split_k_counter(tile_counter_idx): + is_t0_cond = arith.cmpi(arith.CmpIPredicate.eq, fx.Index(tid), fx.Index(0)) + is_t0_cond_if = scf.IfOp(is_t0_cond, results_=[], has_else=False) + with ir.InsertionPoint(is_t0_cond_if.then_block): + counter_base_ptr = fly.extract_aligned_pointer_as_index( + _ptr_type, fly_values(COUNTER)[0] + ) + counter_base_ptr = llvm.PtrToIntOp(_i64_type, counter_base_ptr).result + counter_byte_offset = arith.index_cast( + T.i64, fx.Index(tile_counter_idx) * fx.Index(4) + ) + counter_ptr = llvm.AddOp( + counter_base_ptr, + counter_byte_offset, + llvm.IntegerOverflowFlags(0), + ).result + counter_ptr = llvm.IntToPtrOp(_ptr_type, counter_ptr).result + counter_ptr_v = ( + counter_ptr._value + if hasattr(counter_ptr, "_value") + else counter_ptr + ) + llvm.InlineAsmOp( + None, + 
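+                    # The two inline-asm blocks below (presumably) write the L2
+                    # back (`buffer_wbl2 sc0 sc1`) and then store the counter
+                    # with system-scope bits, so split-K peers polling this
+                    # address observe the initialized value.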
[], + "buffer_wbl2 sc0 sc1", + "", + has_side_effects=True, + ) + llvm.InlineAsmOp( + None, + [counter_ptr_v, arith.constant(1, type=T.i32)], + "global_store_dword $0, $1, off sc0 sc1", + "v,v", + has_side_effects=True, + ) + rocdl.s_waitcnt(0) + scf.YieldOp([]) + + def cleanup_stale_counters_once(): + clean_cond = arith.cmpi( + arith.CmpIPredicate.ult, + fx.Index(tid), + fx.Index(SPLIT_K_COUNTER_MAX_LEN), + ) + clean_cond_if = scf.IfOp(clean_cond, results_=[], has_else=False) + with ir.InsertionPoint(clean_cond_if.then_block): + clean_counter_idx = fx.Int32( + ((signal_state + 2) % 3) * SPLIT_K_COUNTER_MAX_LEN + ) + fx.Index(tid) + COUNTER_[fx.Index(clean_counter_idx)] = arith.constant(0, type=T.i32) + scf.YieldOp([]) + + def split_k_barrier(tile_counter_idx): + init_cur = arith.constant(0, type=T.i32) + w = scf.WhileOp([T.i32], [init_cur]) + before = ir.Block.create_at_start(w.before, [T.i32]) + after = ir.Block.create_at_start(w.after, [T.i32]) + with ir.InsertionPoint(before): + cur = before.arguments[0] + need_wait = arith.CmpIOp( + arith.CmpIPredicate.eq, cur, arith.constant(0, type=T.i32) + ).result + scf.ConditionOp(need_wait, [cur]) + with ir.InsertionPoint(after): + counter_base_ptr = fly.extract_aligned_pointer_as_index( + _ptr_type, fly_values(COUNTER)[0] + ) + counter_base_ptr = llvm.PtrToIntOp(_i64_type, counter_base_ptr).result + counter_byte_offset = arith.index_cast( + T.i64, fx.Index(tile_counter_idx) * fx.Index(4) + ) + counter_ptr = llvm.AddOp( + counter_base_ptr, + counter_byte_offset, + llvm.IntegerOverflowFlags(0), + ).result + counter_ptr = llvm.IntToPtrOp(_ptr_type, counter_ptr).result + counter_ptr_v = ( + counter_ptr._value + if hasattr(counter_ptr, "_value") + else counter_ptr + ) + data = llvm.InlineAsmOp( + T.i32, + [counter_ptr_v], + "global_load_dword $0, $1, off sc1", + "=v,v", + has_side_effects=True, + ).result + rocdl.s_waitcnt(0) + scf.YieldOp([data]) + gpu.barrier() + + def ldg_a(k_offset): + vecs = [] + for i in range_constexpr(LDG_REG_A_COUNT): + global_tid = BLOCK_THREADS * i + tid + m_local_idx = global_tid // LDG_A_X_THREADS + k_local_idx = global_tid % LDG_A_X_THREADS * LDG_VEC_SIZE + row_idx = m_offset + fx.Index(m_local_idx) + col_idx = fx.Index(k_offset + k_local_idx) + slot_valid = arith.cmpi( + arith.CmpIPredicate.ult, + fx.Index(global_tid), + fx.Index(LDG_A_TOTAL_VECS), + ) + valid_row = arith.cmpi(arith.CmpIPredicate.ult, row_idx, fx.Index(m)) + can_load = arith.andi(slot_valid, valid_row) + load_if = scf.IfOp( + can_load, + results_=[T.vec(LDG_VEC_SIZE, dtype_)], + has_else=True, + ) + with ir.InsertionPoint(load_if.then_block): + scf.YieldOp([A_.vec_load((row_idx, col_idx), LDG_VEC_SIZE)]) + with ir.InsertionPoint(load_if.else_block): + scf.YieldOp([zero_a_vec]) + vecs.append(load_if.results[0]) + return vecs + + def sts_a(vecs, lds_stage): + for i in range_constexpr(LDG_REG_A_COUNT): + global_tid = BLOCK_THREADS * i + tid + m_local_idx = global_tid // LDG_A_X_THREADS + k_local_idx = global_tid % LDG_A_X_THREADS * LDG_VEC_SIZE + col_in_bytes = k_local_idx * DTYPE_BYTES + col_in_bytes = swizzle_xor16(m_local_idx, col_in_bytes, k_blocks16) + slot_valid = arith.cmpi( + arith.CmpIPredicate.ult, + fx.Index(global_tid), + fx.Index(LDG_A_TOTAL_VECS), + ) + store_if = scf.IfOp(slot_valid, results_=[], has_else=False) + with ir.InsertionPoint(store_if.then_block): + as_.vec_store( + (fx.Index(lds_stage), m_local_idx, col_in_bytes // DTYPE_BYTES), + vecs[i], + LDG_VEC_SIZE, + ) + scf.YieldOp([]) + + def ldg_sts_a_async(k_offset, lds_stage): + for 
i in range_constexpr(LDG_REG_A_COUNT_AS): + global_tid = BLOCK_THREADS * i + tid + m_local_idx = global_tid // LDG_A_X_THREADS_AS + k_local_idx = global_tid % LDG_A_X_THREADS_AS * LDG_ASYNC_VEC_SIZE + col_in_bytes = k_local_idx * DTYPE_BYTES + col_in_bytes = swizzle_xor16(m_local_idx, col_in_bytes, k_blocks16) + row_idx = m_offset + fx.Index(m_local_idx) + col_idx = fx.Index(k_offset + col_in_bytes // DTYPE_BYTES) + slot_valid = arith.cmpi( + arith.CmpIPredicate.ult, + fx.Index(global_tid), + fx.Index(LDG_A_TOTAL_VECS_AS), + ) + slot_if = scf.IfOp(slot_valid, results_=[], has_else=False) + with ir.InsertionPoint(slot_if.then_block): + valid_row = arith.cmpi( + arith.CmpIPredicate.ult, row_idx, fx.Index(m) + ) + cond_if = scf.IfOp(valid_row, results_=[], has_else=True) + with ir.InsertionPoint(cond_if.then_block): + global_offset = ( + A_.linear_offset((row_idx, col_idx)) * DTYPE_BYTES + ) + global_offset = arith.index_cast(T.i32, global_offset) + lds_offset = ( + as_.linear_offset( + (fx.Index(lds_stage), m_local_idx, k_local_idx) + ) + * DTYPE_BYTES + ) + lds_ptr_type = ir.Type.parse("!llvm.ptr<3>") + lds_addr = ( + memref.extract_aligned_pointer_as_index(as_.memptr) + + lds_offset + ) + lds_addr_ = rocdl.readfirstlane( + T.i64, arith.index_cast(T.i64, lds_addr) + ) + lds_ptr = llvm.inttoptr(lds_ptr_type, lds_addr_) + rocdl.raw_ptr_buffer_load_lds( + A_.rsrc, + lds_ptr, + arith.constant(DMA_BYTES, type=T.i32), + global_offset, + arith.constant(0, type=T.i32), + arith.constant(0, type=T.i32), + arith.constant(1, type=T.i32), + ) + scf.YieldOp([]) + with ir.InsertionPoint(cond_if.else_block): + as_.vec_store( + (fx.Index(lds_stage), m_local_idx, k_local_idx), + zero_a_async_vec, + LDG_ASYNC_VEC_SIZE, + ) + scf.YieldOp([]) + scf.YieldOp([]) + + def ldg_sts_b_async(k_offset, lds_stage, tile_n_offset): + for i in range_constexpr(LDG_REG_B_COUNT_AS): + global_tid = BLOCK_THREADS * i + tid + n_local_idx = global_tid // LDG_B_X_THREADS_AS + k_local_idx = global_tid % LDG_B_X_THREADS_AS * LDG_ASYNC_VEC_SIZE + col_in_bytes = k_local_idx * DTYPE_BYTES + col_in_bytes = swizzle_xor16(n_local_idx, col_in_bytes, k_blocks16) + col_idx = fx.Index(k_offset + col_in_bytes // DTYPE_BYTES) + slot_valid = arith.cmpi( + arith.CmpIPredicate.ult, + fx.Index(global_tid), + fx.Index(LDG_B_TOTAL_VECS_AS), + ) + slot_if = scf.IfOp(slot_valid, results_=[], has_else=False) + with ir.InsertionPoint(slot_if.then_block): + global_offset = B_.linear_offset( + (tile_n_offset + fx.Index(n_local_idx), col_idx) + ) + global_offset = arith.index_cast(T.i32, global_offset * DTYPE_BYTES) + lds_offset = ( + bs_.linear_offset( + (fx.Index(lds_stage), n_local_idx, k_local_idx) + ) + * DTYPE_BYTES + ) + lds_ptr_type = ir.Type.parse("!llvm.ptr<3>") + lds_addr = ( + memref.extract_aligned_pointer_as_index(bs_.memptr) + lds_offset + ) + lds_addr_ = rocdl.readfirstlane( + T.i64, arith.index_cast(T.i64, lds_addr) + ) + lds_ptr = llvm.inttoptr(lds_ptr_type, lds_addr_) + rocdl.raw_ptr_buffer_load_lds( + B_.rsrc, + lds_ptr, + arith.constant(DMA_BYTES, type=T.i32), + global_offset, + arith.constant(0, type=T.i32), + arith.constant(0, type=T.i32), + arith.constant(1, type=T.i32), + ) + scf.YieldOp([]) + + def lds_matrix_a(lds_stage): + s = fx.Index(lds_stage) + a_frags = [0] * A_FRAGS_LEN + for ii in range_constexpr(WARP_M_STEPS): + warp_atom_m_idx = warp_m_idx + ii * WARP_ATOM_M + for kk in range_constexpr(WARP_K_STEPS): + warp_atom_k_idx = kk * WARP_ATOM_K + row = warp_atom_m_idx + ldmatrix_a_m_idx + col_in_bytes = ( + warp_atom_k_idx + 
ldmatrix_a_k_vec_idx + ) * DTYPE_BYTES + col_in_bytes = swizzle_xor16(row, col_in_bytes, k_blocks16) + vec = as_.vec_load( + (s, row, col_in_bytes // DTYPE_BYTES), + WMMA_A_FRAG_VALUES * MFMA_PER_WARP_K, + ) + a_frags[kk * WARP_M_STEPS + ii] = vec + return a_frags + + def lds_matrix_b(lds_stage): + s = fx.Index(lds_stage) + b_frags = [0] * B_FRAGS_LEN + for ii in range_constexpr(WARP_N_STEPS): + warp_atom_n_idx = warp_n_idx + ii * WARP_ATOM_N + for kk in range_constexpr(WARP_K_STEPS): + warp_atom_k_idx = kk * WARP_ATOM_K + row = warp_atom_n_idx + ldmatrix_b_n_idx + col_in_bytes = ( + warp_atom_k_idx + ldmatrix_b_k_vec_idx + ) * DTYPE_BYTES + col_in_bytes = swizzle_xor16(row, col_in_bytes, k_blocks16) + vec = bs_.vec_load( + (s, row, col_in_bytes // DTYPE_BYTES), + WMMA_B_FRAG_VALUES * MFMA_PER_WARP_K, + ) + b_frags[kk * WARP_N_STEPS + ii] = vec + return b_frags + + def ldg_matrix_b(k_offset, tile_n_offset): + vecs = [] + for kk in range_constexpr(WARP_K_STEPS): + warp_atom_k_idx = kk * WARP_ATOM_K + for ii in range_constexpr(WARP_N_STEPS): + warp_atom_n_idx = warp_n_idx + ii * WARP_ATOM_N + n_idx = tile_n_offset + warp_atom_n_idx + ldmatrix_b_n_idx + k_idx = k_offset + warp_atom_k_idx + ldmatrix_b_k_vec_idx + vec = B_.vec_load( + (n_idx, k_idx), WMMA_B_FRAG_VALUES * MFMA_PER_WARP_K + ) + vecs.append(vec) + return vecs + + def maybe_ldg_matrix_b(k_offset, tile_n_offset, tile_active): + if N_TILE_REPEAT == 1: + return ldg_matrix_b(k_offset, tile_n_offset) + load_if = scf.IfOp( + tile_active, + results_=[B_FRAG_T] * B_FRAGS_LEN, + has_else=True, + ) + with ir.InsertionPoint(load_if.then_block): + scf.YieldOp(ldg_matrix_b(k_offset, tile_n_offset)) + with ir.InsertionPoint(load_if.else_block): + scf.YieldOp([zero_b_frag] * B_FRAGS_LEN) + return list(load_if.results) + + def block_mma_sync(a_frags, b_frags, c_frags): + c_frags_new = [cx for cx in c_frags] + for kk in range_constexpr(WARP_K_STEPS): + for ii in range_constexpr(WARP_M_STEPS): + a_frag = a_frags[kk * WARP_M_STEPS + ii] + for jj in range_constexpr(WARP_N_STEPS): + b_frag = b_frags[kk * WARP_N_STEPS + jj] + c_idx = ii * WARP_N_STEPS + jj + c_frags_new[c_idx] = WMMA_IMPL( + a_frag, b_frag, c_frags_new[c_idx] + ) + return c_frags_new + + def store_split_k_tile(tile_n_offset): + out_raw = fly_values(C)[0] + out_base_ptr = fly.extract_aligned_pointer_as_index(_ptr_type, out_raw) + out_base_int = llvm.PtrToIntOp(_i64_type, out_base_ptr).result + for i in range_constexpr(LDG_REG_C_COUNT): + global_tid = BLOCK_THREADS * i + tid + m_local_idx = fx.Index(global_tid // LDG_C_X_THREADS) + n_local_idx = fx.Index(global_tid % LDG_C_X_THREADS * LDG_VEC_SIZE) + m_global_idx = m_offset + m_local_idx + n_global_idx = tile_n_offset + n_local_idx + cond_boundary = arith.cmpi( + arith.CmpIPredicate.ult, m_global_idx, fx.Index(m) + ) + cond_boundary_if = scf.IfOp(cond_boundary, results_=[], has_else=False) + with ir.InsertionPoint(cond_boundary_if.then_block): + pk_val = cs_.vec_load((m_local_idx, n_local_idx), LDG_VEC_SIZE) + linear_bytes_offset = ( + C_.linear_offset((m_global_idx, n_global_idx)) * DTYPE_BYTES + ) + vec2_ty = T.vec(2, dtype_) + for vec_idx in range_constexpr(LDG_VEC_SIZE // 2): + e0 = vector.extract( + pk_val, + static_position=[vec_idx * 2], + dynamic_position=[], + ) + e1 = vector.extract( + pk_val, + static_position=[vec_idx * 2 + 1], + dynamic_position=[], + ) + pair = vector.from_elements(vec2_ty, [e0, e1]) + pair_byte_offset = arith.index_cast( + T.i64, + linear_bytes_offset + fx.Index(vec_idx * 2 * DTYPE_BYTES), + ) + 
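+                        # Split-K partials are merged straight into C: each
+                        # 2 x bf16 pair is combined with an agent-scope,
+                        # monotonic atomic fadd, so no separate reduction
+                        # kernel is needed.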
pair_addr_i64 = llvm.AddOp( + out_base_int, + pair_byte_offset, + llvm.IntegerOverflowFlags(0), + ).result + pair_ptr = llvm.IntToPtrOp(_ptr_type, pair_addr_i64).result + pair_ptr_v = ( + pair_ptr._value if hasattr(pair_ptr, "_value") else pair_ptr + ) + pair_v = pair._value if hasattr(pair, "_value") else pair + llvm.AtomicRMWOp( + llvm.AtomicBinOp.fadd, + pair_ptr_v, + pair_v, + llvm.AtomicOrdering.monotonic, + syncscope="agent", + alignment=4, + ) + scf.YieldOp([]) + + def store_c_tile(tile_n_offset): + for i in range_constexpr(LDG_REG_C_COUNT): + global_tid = BLOCK_THREADS * i + tid + m_local_idx = fx.Index(global_tid // LDG_C_X_THREADS) + n_local_idx = fx.Index(global_tid % LDG_C_X_THREADS * LDG_VEC_SIZE) + m_global_idx = m_offset + m_local_idx + cond_boundary = arith.cmpi( + arith.CmpIPredicate.ult, m_global_idx, fx.Index(m) + ) + cond_boundary_if = scf.IfOp(cond_boundary, results_=[], has_else=False) + with ir.InsertionPoint(cond_boundary_if.then_block): + vec = cs_.vec_load((m_local_idx, n_local_idx), LDG_VEC_SIZE) + if HAS_BIAS: + bias_vec = BIAS_.vec_load( + (tile_n_offset + n_local_idx,), LDG_VEC_SIZE + ) + vec = vec + bias_vec + C_.vec_store( + (m_global_idx, tile_n_offset + n_local_idx), vec, LDG_VEC_SIZE + ) + scf.YieldOp([]) + + stmatrix_c_m_vec_idx = w_tid // WMMA_N * WMMA_C_FRAG_VALUES + stmatrix_c_n_idx = w_tid % WMMA_N + + def write_c_frags_to_lds(tile_c_frags_): + for ii in range_constexpr(WARP_M_STEPS): + warp_atom_m_idx = warp_m_idx + ii * WARP_ATOM_M + for jj in range_constexpr(WARP_N_STEPS): + warp_atom_n_idx = warp_n_idx + jj * WARP_ATOM_N + for kk in range_constexpr(WMMA_C_FRAG_VALUES): + lds_m_idx = fx.Index( + warp_atom_m_idx + stmatrix_c_m_vec_idx + kk + ) + lds_n_idx = fx.Index(warp_atom_n_idx + stmatrix_c_n_idx) + val = vector.extract( + tile_c_frags_[ii * WARP_N_STEPS + jj], + static_position=[kk], + dynamic_position=[], + ) + cs_[lds_m_idx, lds_n_idx] = val.truncf(dtype_) + + if IS_SPLIT_K: + cond_ks0 = arith.cmpi(arith.CmpIPredicate.eq, ks_idx, fx.Index(0)) + if not B_TO_LDS: + cond_ks0_if = scf.IfOp(cond_ks0, results_=[], has_else=False) + with ir.InsertionPoint(cond_ks0_if.then_block): + for tile_i in range_constexpr(N_TILE_REPEAT): + tile_init_if = scf.IfOp( + tile_actives[tile_i], results_=[], has_else=False + ) + with ir.InsertionPoint(tile_init_if.then_block): + zero_c_tile(tile_n_offsets[tile_i]) + scf.YieldOp([]) + scf.YieldOp([]) + rocdl.sched_barrier(0) + gpu.barrier() + + cond_ks0_if = scf.IfOp(cond_ks0, results_=[], has_else=False) + with ir.InsertionPoint(cond_ks0_if.then_block): + for tile_i in range_constexpr(N_TILE_REPEAT): + tile_init_if = scf.IfOp( + tile_actives[tile_i], results_=[], has_else=False + ) + with ir.InsertionPoint(tile_init_if.then_block): + init_split_k_counter(tile_counter_indices[tile_i]) + scf.YieldOp([]) + scf.YieldOp([]) + rocdl.sched_barrier(0) + gpu.barrier() + + cond_ks0_if = scf.IfOp(cond_ks0, results_=[], has_else=False) + with ir.InsertionPoint(cond_ks0_if.then_block): + cleanup_stale_counters_once() + scf.YieldOp([]) + rocdl.sched_barrier(0) + gpu.barrier() + + if B_TO_LDS: + + def run_b_to_lds_tile(tile_n_offset, tile_counter_idx): + c_frags_local = [acc_init] * C_FRAGS_LEN + if IS_SPLIT_K: + cond_ks0_if = scf.IfOp(cond_ks0, results_=[], has_else=False) + with ir.InsertionPoint(cond_ks0_if.then_block): + zero_c_tile(tile_n_offset) + scf.YieldOp([]) + rocdl.sched_barrier(0) + gpu.barrier() + + cond_ks0_if = scf.IfOp(cond_ks0, results_=[], has_else=False) + with ir.InsertionPoint(cond_ks0_if.then_block): + 
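+                        # The counter flips to 1 only after this tile of C has
+                        # been zero/bias-initialized; split_k_barrier() spins
+                        # until it is non-zero, so blocks with ks_idx > 0 never
+                        # atomically add into uninitialized output.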
init_split_k_counter(tile_counter_idx) + scf.YieldOp([]) + rocdl.sched_barrier(0) + gpu.barrier() + + ldg_sts_a_async(ks_begin, 0) + ldg_sts_b_async(ks_begin, 0, tile_n_offset) + gpu.barrier() + + def hot_loop_scheduler(): + MFMA_TOTAL = ( + WARP_K_STEPS * WARP_M_STEPS * WARP_N_STEPS * MFMA_PER_WARP_K + ) + LDG_TOTAL = LDG_REG_A_COUNT_AS + LDG_REG_B_COUNT_AS + if WIDE_N_B_TO_LDS: + for _ in range_constexpr(WARP_K_STEPS * WARP_M_STEPS): + rocdl.sched_dsrd(1) + for _ in range_constexpr(WARP_K_STEPS * WARP_N_STEPS): + rocdl.sched_dsrd(1) + for _ in range_constexpr(LDG_REG_A_COUNT_AS): + rocdl.sched_vmem(1) + rocdl.sched_mfma(2) + for _ in range_constexpr(LDG_REG_B_COUNT_AS): + rocdl.sched_vmem(1) + rocdl.sched_mfma(2) + remaining = max(MFMA_TOTAL - LDG_TOTAL * 2, 0) + for _ in range_constexpr(remaining): + rocdl.sched_mfma(1) + else: + for _ in range_constexpr(WARP_K_STEPS * WARP_M_STEPS): + rocdl.sched_dsrd(1) + for _ in range_constexpr(WARP_K_STEPS * WARP_N_STEPS): + rocdl.sched_dsrd(1) + for _ in range_constexpr(LDG_TOTAL): + rocdl.sched_vmem(1) + rocdl.sched_mfma(2) + remaining = max(MFMA_TOTAL - LDG_TOTAL * 2, 0) + for _ in range_constexpr(remaining): + rocdl.sched_mfma(1) + rocdl.sched_barrier(0) + + UNROLL = EFFECTIVE_B_TO_LDS_UNROLL + init_state = [ks_begin, arith.constant(0, index=True)] + c_frags_local + for bki, state in range(0, BLOCK_K_LOOPS - 1, UNROLL, init=init_state): + k_offset = state[0] + current_stage = fx.Index(state[1]) + c_frags_local = state[2 : 2 + C_FRAGS_LEN] + for unroll_i in range_constexpr(UNROLL): + cond = arith.cmpi( + arith.CmpIPredicate.ult, + fx.Index(bki + unroll_i), + fx.Index(BLOCK_K_LOOPS - 1), + ) + cond_if = scf.IfOp( + cond, + results_=[T.vec(WMMA_C_FRAG_VALUES, T.f32)] * C_FRAGS_LEN + + [T.index, T.i32], + has_else=True, + ) + with ir.InsertionPoint(cond_if.then_block): + next_stage = 1 - current_stage + a_frags = lds_matrix_a(current_stage) + b_frags = lds_matrix_b(current_stage) + ldg_sts_a_async(k_offset + BLOCK_K, next_stage) + ldg_sts_b_async( + k_offset + BLOCK_K, next_stage, tile_n_offset + ) + c_frags_new = block_mma_sync( + a_frags, b_frags, c_frags_local + ) + hot_loop_scheduler() + gpu.barrier() + k_offset_next = k_offset + fx.Int32(BLOCK_K) + current_stage_next = 1 - current_stage + scf.YieldOp( + c_frags_new + + [_to_raw(current_stage_next), k_offset_next] + ) + with ir.InsertionPoint(cond_if.else_block): + scf.YieldOp( + c_frags_local + [_to_raw(current_stage), k_offset] + ) + c_frags_local = [cond_if.results[i] for i in range(C_FRAGS_LEN)] + current_stage = cond_if.results[C_FRAGS_LEN] + k_offset = cond_if.results[C_FRAGS_LEN + 1] + results = yield [k_offset, current_stage] + c_frags_local + current_stage = results[1] + c_frags_local = results[2 : 2 + C_FRAGS_LEN] + a_frags = lds_matrix_a(current_stage) + b_frags = lds_matrix_b(current_stage) + c_frags_local = block_mma_sync(a_frags, b_frags, c_frags_local) + + write_c_frags_to_lds(c_frags_local) + gpu.barrier() + if IS_SPLIT_K: + split_k_barrier(tile_counter_idx) + store_split_k_tile(tile_n_offset) + else: + store_c_tile(tile_n_offset) + gpu.barrier() + + for tile_i in range_constexpr(tile_group): + tile_exec_if = scf.IfOp( + tile_actives[tile_i], results_=[], has_else=False + ) + with ir.InsertionPoint(tile_exec_if.then_block): + run_b_to_lds_tile( + tile_n_offsets[tile_i], tile_counter_indices[tile_i] + ) + scf.YieldOp([]) + else: + sts_a(ldg_a(ks_begin), 0) + gpu.barrier() + a_frags = lds_matrix_a(0) + b_frags = [] + for tile_i in range_constexpr(N_TILE_REPEAT): + b_frags.extend( 
+ maybe_ldg_matrix_b( + ks_begin, + tile_n_offsets[tile_i], + tile_actives[tile_i], + ) + ) + rocdl.sched_barrier(0) + + def hot_loop_scheduler(): + MFMA_TOTAL = ( + N_TILE_REPEAT + * WARP_K_STEPS + * WARP_M_STEPS + * WARP_N_STEPS + * MFMA_PER_WARP_K + ) + LDG_TOTAL = ( + LDG_REG_A_COUNT_AS + N_TILE_REPEAT * WARP_K_STEPS * WARP_N_STEPS + ) + avg_mfma_count = (MFMA_TOTAL + LDG_TOTAL - 1) // LDG_TOTAL + mfma_sched = OnlineScheduler(MFMA_TOTAL, MFMA_TOTAL) + ldg_sched = OnlineScheduler(LDG_TOTAL, LDG_TOTAL) + for _ in range_constexpr(LDG_TOTAL): + rocdl.sched_vmem(ldg_sched.consume(1)) + rocdl.sched_mfma(mfma_sched.consume(avg_mfma_count)) + rocdl.sched_barrier(0) + + TOTAL_C_FRAGS_LEN = C_FRAGS_LEN * N_TILE_REPEAT + TOTAL_B_FRAGS_LEN = B_FRAGS_LEN * N_TILE_REPEAT + init_state = ( + [ks_begin, arith.constant(0, index=True)] + c_frags + a_frags + b_frags + ) + for _, state in range(1, BLOCK_K_LOOPS, init=init_state): + k_offset = state[0] + current_stage = fx.Index(state[1]) + next_stage = 1 - current_stage + c_frags = state[2 : 2 + TOTAL_C_FRAGS_LEN] + a_frags = state[ + 2 + TOTAL_C_FRAGS_LEN : 2 + TOTAL_C_FRAGS_LEN + A_FRAGS_LEN + ] + b_frags = state[ + 2 + + TOTAL_C_FRAGS_LEN + + A_FRAGS_LEN : 2 + + TOTAL_C_FRAGS_LEN + + A_FRAGS_LEN + + TOTAL_B_FRAGS_LEN + ] + ldg_sts_a_async(k_offset + BLOCK_K, next_stage) + b_frags_next = [] + c_frags_next = [] + for tile_i in range_constexpr(N_TILE_REPEAT): + b_start = tile_i * B_FRAGS_LEN + c_start = tile_i * C_FRAGS_LEN + b_frags_next.extend( + maybe_ldg_matrix_b( + k_offset + BLOCK_K, + tile_n_offsets[tile_i], + tile_actives[tile_i], + ) + ) + c_frags_next.extend( + block_mma_sync( + a_frags, + b_frags[b_start : b_start + B_FRAGS_LEN], + c_frags[c_start : c_start + C_FRAGS_LEN], + ) + ) + c_frags = c_frags_next + hot_loop_scheduler() + gpu.barrier() + a_frags_next = lds_matrix_a(next_stage) + k_offset = k_offset + fx.Int32(BLOCK_K) + rocdl.sched_barrier(0) + results = ( + yield [k_offset, next_stage] + c_frags + a_frags_next + b_frags_next + ) + c_frags = results[2 : 2 + TOTAL_C_FRAGS_LEN] + a_frags = results[ + 2 + TOTAL_C_FRAGS_LEN : 2 + TOTAL_C_FRAGS_LEN + A_FRAGS_LEN + ] + b_frags = results[ + 2 + + TOTAL_C_FRAGS_LEN + + A_FRAGS_LEN : 2 + + TOTAL_C_FRAGS_LEN + + A_FRAGS_LEN + + TOTAL_B_FRAGS_LEN + ] + c_frags_next = [] + for tile_i in range_constexpr(N_TILE_REPEAT): + b_start = tile_i * B_FRAGS_LEN + c_start = tile_i * C_FRAGS_LEN + c_frags_next.extend( + block_mma_sync( + a_frags, + b_frags[b_start : b_start + B_FRAGS_LEN], + c_frags[c_start : c_start + C_FRAGS_LEN], + ) + ) + c_frags = c_frags_next + + tile_c_frags = [ + c_frags[tile_i * C_FRAGS_LEN : (tile_i + 1) * C_FRAGS_LEN] + for tile_i in range_constexpr(N_TILE_REPEAT) + ] + + for tile_i in range_constexpr(N_TILE_REPEAT): + tile_store_if = scf.IfOp( + tile_actives[tile_i], results_=[], has_else=False + ) + with ir.InsertionPoint(tile_store_if.then_block): + write_c_frags_to_lds(tile_c_frags[tile_i]) + gpu.barrier() + if IS_SPLIT_K: + split_k_barrier(tile_counter_indices[tile_i]) + store_split_k_tile(tile_n_offsets[tile_i]) + else: + store_c_tile(tile_n_offsets[tile_i]) + gpu.barrier() + scf.YieldOp([]) + + @flyc.jit + def launch_small_m_hgemm_kernel( + C: fx.Tensor, + A: fx.Tensor, + B: fx.Tensor, + BIAS: fx.Tensor, + m: fx.Int32, + COUNTER: fx.Tensor, + signal_state: fx.Int32, + stream: fx.Stream = fx.Stream(None), + ): + allocator.finalized = False + ctx = CompilationContext.get_current() + with ir.InsertionPoint(ctx.gpu_module_body): + allocator.finalize() + if WAVES_PER_EU > 0: + for 
op in ctx.gpu_module_body.operations: + if hasattr(op, "attributes") and op.OPERATION_NAME == "gpu.func": + op.attributes["rocdl.waves_per_eu"] = ir.IntegerAttr.get( + T.i32, int(WAVES_PER_EU) + ) + + bm = (m + BLOCK_M - 1) // BLOCK_M + tile_group = PERSISTENT_N_TILES if PERSISTENT_N else N_TILE_REPEAT + bn = (n // BLOCK_N + tile_group - 1) // tile_group + small_m_hgemm_kernel._func.__name__ = KERNEL_NAME + small_m_hgemm_kernel(C, A, B, BIAS, m, COUNTER, signal_state).launch( + grid=(bm, bn, SPLIT_K), + block=(BLOCK_THREADS, 1, 1), + stream=stream, + ) + + return launch_small_m_hgemm_kernel diff --git a/aiter/ops/flydsl/kernels/splitk_hgemm.py b/aiter/ops/flydsl/kernels/splitk_hgemm.py index 773b6afc2c..6c4593441b 100644 --- a/aiter/ops/flydsl/kernels/splitk_hgemm.py +++ b/aiter/ops/flydsl/kernels/splitk_hgemm.py @@ -112,6 +112,7 @@ def compile_hgemm_kernel( BLOCK_N_WARPS: int = 4, B_PRE_SHUFFLE: bool = False, B_TO_LDS: bool = False, + HAS_BIAS: bool = False, ): IS_SPLIT_K = SPLIT_K > 1 BLOCK_K = TILE_K @@ -187,6 +188,8 @@ def compile_hgemm_kernel( KERNEL_NAME += f"_SPK{SPLIT_K}" if B_TO_LDS: KERNEL_NAME += "_BS" + if HAS_BIAS: + KERNEL_NAME += "_BIAS" allocator = SmemAllocator(None, arch=GPU_ARCH, global_sym_name="smem") smem_a_offset = allocator._align(allocator.ptr, 16) @@ -217,6 +220,7 @@ def hgemm_kernel( C: fx.Tensor, A: fx.Tensor, B: fx.Tensor, + BIAS: fx.Tensor, m: fx.Int32, COUNTER: fx.Tensor, signal_state: fx.Int32, @@ -230,6 +234,8 @@ def hgemm_kernel( A_ = GTensor(A, dtype=dtype_, shape=(-1, k)) B_ = GTensor(B, dtype=dtype_, shape=(n, k)) C_ = GTensor(C, dtype=dtype_, shape=(-1, n)) + if HAS_BIAS: + BIAS_ = GTensor(BIAS, dtype=dtype_, shape=(n,)) base_ptr = allocator.get_base() smem_a_ptr = SmemPtr( base_ptr, smem_a_offset, dtype_, shape=(STAGES * BLOCK_M * BLOCK_K,) @@ -296,6 +302,11 @@ def zero_c(): m_local_idx = global_tid // LDG_C_X_THREADS n_local_idx = global_tid % LDG_C_X_THREADS * LDG_VEC_SIZE row_idx = m_offset + fx.Index(m_local_idx) + init_vec = zero_vec + if HAS_BIAS: + init_vec = BIAS_.vec_load( + (n_offset + n_local_idx,), LDG_VEC_SIZE + ) cond_boundary = arith.cmpi( arith.CmpIPredicate.ult, row_idx, fx.Index(m) ) @@ -304,7 +315,7 @@ def zero_c(): ) with ir.InsertionPoint(cond_boundary_if.then_block): C_.vec_store( - (row_idx, n_offset + n_local_idx), zero_vec, LDG_VEC_SIZE + (row_idx, n_offset + n_local_idx), init_vec, LDG_VEC_SIZE ) scf.YieldOp([]) scf.YieldOp([]) @@ -887,6 +898,11 @@ def hot_loop_scheduler(): cond_boundary_if = scf.IfOp(cond_boundary, results_=[], has_else=False) with ir.InsertionPoint(cond_boundary_if.then_block): vec = cs_.vec_load((m_local_idx, n_local_idx), LDG_VEC_SIZE) + if HAS_BIAS: + bias_vec = BIAS_.vec_load( + (n_offset + n_local_idx,), LDG_VEC_SIZE + ) + vec = vec + bias_vec C_.vec_store( (m_global_idx, n_offset + n_local_idx), vec, LDG_VEC_SIZE ) @@ -898,6 +914,7 @@ def launch_hgemm_kernel( C: fx.Tensor, A: fx.Tensor, B: fx.Tensor, + BIAS: fx.Tensor, m: fx.Int32, COUNTER: fx.Tensor, signal_state: fx.Int32, @@ -911,7 +928,7 @@ def launch_hgemm_kernel( bm = (m + BLOCK_M - 1) // BLOCK_M bn = n // BLOCK_N hgemm_kernel._func.__name__ = KERNEL_NAME - hgemm_kernel(C, A, B, m, COUNTER, signal_state).launch( + hgemm_kernel(C, A, B, BIAS, m, COUNTER, signal_state).launch( grid=(bm, bn, SPLIT_K), block=(BLOCK_THREADS, 1, 1), stream=stream, diff --git a/aiter/tuned_gemm.py b/aiter/tuned_gemm.py index eb4aaef62f..503b83a158 100644 --- a/aiter/tuned_gemm.py +++ b/aiter/tuned_gemm.py @@ -441,22 +441,40 @@ def flydsl_gemm( assert ( scale_a is None 
and scale_b is None and scale_c is None ), "FlyDSL hgemm does not support scaling yet." + stages = config.get("stages", config.get("stage", 2)) + fused_bias = None + if ( + bias is not None + and (otype is None or otype == inp.dtype) + and bias.dtype == inp.dtype + ): + fused_bias = bias out = aiter.ops.flydsl.gemm_kernels.flydsl_hgemm( inp, weights, + bias=fused_bias, + kernel_family=config.get("kernel_family"), tile_m=config["tile_m"], tile_n=config["tile_n"], tile_k=config["tile_k"], split_k=config["split_k"], block_m_warps=config["block_m_warps"], block_n_warps=config["block_n_warps"], + n_tile_repeat=config.get("n_tile_repeat", 1), + persistent_n_tiles=config.get("persistent_n_tiles", 1), + waves_per_eu=config.get("waves_per_eu", 0), + b_to_lds_unroll=config.get("b_to_lds_unroll", 0), + stages=stages, + async_copy=config.get("async_copy", False), b_to_lds=config["b_to_lds"], b_preshuffle=config["b_preshuffle"], + c_to_lds=config.get("c_to_lds", False), ) + + if bias is not None and fused_bias is None: + out = out.to(bias.dtype) + bias if otype is not None and out.dtype != otype: out = out.to(otype) - if bias is not None: - out = out + bias return out diff --git a/gradlib/gradlib/GemmTuner.py b/gradlib/gradlib/GemmTuner.py index 681f07d72a..2a4db4d332 100644 --- a/gradlib/gradlib/GemmTuner.py +++ b/gradlib/gradlib/GemmTuner.py @@ -97,40 +97,55 @@ def run_flydsl_gemm_bf16(input, weight, bias=None, otype=dtypes.bf16, config=Non raise RuntimeError(f"flydsl is not available for tuning: {FLYDSL_TUNE_ERROR}") if config is None: raise ValueError("flydsl tuning requires a kernel config") + stages = config.get("stages", config.get("stage", 2)) + fused_bias = None + if ( + bias is not None + and (otype is None or otype == input.dtype) + and bias.dtype == input.dtype + ): + fused_bias = bias out = flydsl_hgemm( input, weight, + bias=fused_bias, + kernel_family=config.get("kernel_family"), tile_m=config["tile_m"], tile_n=config["tile_n"], tile_k=config["tile_k"], split_k=config["split_k"], block_m_warps=config["block_m_warps"], block_n_warps=config["block_n_warps"], - stages=config["stage"], - async_copy=config["async_copy"], + n_tile_repeat=config.get("n_tile_repeat", 1), + persistent_n_tiles=config.get("persistent_n_tiles", 1), + waves_per_eu=config.get("waves_per_eu", 0), + b_to_lds_unroll=config.get("b_to_lds_unroll", 0), + stages=stages, + async_copy=config.get("async_copy", False), b_to_lds=config["b_to_lds"], b_preshuffle=config["b_preshuffle"], auto_shuffle_b=False, - c_to_lds=config["c_to_lds"], + c_to_lds=config.get("c_to_lds", False), ) + + if bias is not None and fused_bias is None: + out = out.to(bias.dtype) + bias if otype is not None and out.dtype != otype: out = out.to(otype) - if bias is not None: - if bias.dtype != out.dtype: - bias = bias.to(out.dtype) - out = out + bias return out -@lru_cache(maxsize=1) -def get_flydsl_bf16_catalog(): +@lru_cache(maxsize=1024) +def get_flydsl_bf16_catalog(m: int, n: int, k: int): if get_flydsl_splitk_hgemm_kernels is None: return [] - kernels = get_flydsl_splitk_hgemm_kernels("bf16", "bf16") + kernels = get_flydsl_splitk_hgemm_kernels("bf16", "bf16", m=m, n=n, k=k) catalog = [ (idx, name, dict(kernels[name])) for idx, name in enumerate(sorted(kernels)) ] - logger.info(f"FlyDSL bf16 catalog size: {len(catalog)} kernels") + logger.info( + f"FlyDSL bf16 catalog size for M={m}, N={n}, K={k}: {len(catalog)} kernels" + ) return catalog @@ -520,7 +535,7 @@ def flydsl_gemm_all_sols(self): return [] task = [] - flydsl_catalog = get_flydsl_bf16_catalog() 
+ flydsl_catalog = get_flydsl_bf16_catalog(self.m, self.n, self.k) weight_idx = 6 if self.is_shuffle else 1 for solidx, kernel_name, config in flydsl_catalog: if config["b_preshuffle"] != self.is_shuffle: @@ -587,7 +602,7 @@ def flydsl_gemm_all_sols(self): logger.info( "FlyDSL candidate count for " f"M={self.m}, N={self.n}, K={self.k}, outdtype={self.outdtype}, " - f"bpreshuffle={self.is_shuffle}: {len(task)}/{len(flydsl_catalog)}" + f"bpreshuffle={self.is_shuffle}: {len(task)}" ) return task From f10ab1371c90ce1dcf2c7b76241c356ca2dec749 Mon Sep 17 00:00:00 2001 From: XiaobingSuper Date: Fri, 17 Apr 2026 04:01:37 -0500 Subject: [PATCH 02/11] update gptos config --- .../model_configs/gptoss_bf16_tuned_gemm.csv | 114 +++++----- .../gptoss_bf16_untuned_gemm_smallm.csv | 26 +++ ...ptoss_bf16_untuned_gemm_smallm_large_n.csv | 11 + aiter/ops/flydsl/gemm_kernels.py | 135 ++++++++---- aiter/ops/flydsl/kernels/small_m_hgemm.py | 195 +++++++++--------- 5 files changed, 283 insertions(+), 198 deletions(-) create mode 100644 aiter/configs/model_configs/gptoss_bf16_untuned_gemm_smallm.csv create mode 100644 aiter/configs/model_configs/gptoss_bf16_untuned_gemm_smallm_large_n.csv diff --git a/aiter/configs/model_configs/gptoss_bf16_tuned_gemm.csv b/aiter/configs/model_configs/gptoss_bf16_tuned_gemm.csv index 1deaa684f9..5587bed4c7 100644 --- a/aiter/configs/model_configs/gptoss_bf16_tuned_gemm.csv +++ b/aiter/configs/model_configs/gptoss_bf16_tuned_gemm.csv @@ -1,58 +1,58 @@ gfx,cu_num,M,N,K,bias,dtype,outdtype,scaleAB,bpreshuffle,libtype,solidx,splitK,us,kernelName,err_ratio,tflops,bw -gfx950,256,1,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,14,4.9558,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0234,0.15,149.99 -gfx950,256,2,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,13,4.9466,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0195,0.3,151.48 -gfx950,256,4,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,14,4.9687,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0156,0.59,153.23 -gfx950,256,8,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,14,4.9927,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0176,1.18,157.31 -gfx950,256,16,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,13,5.031,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0171,2.34,165.68 -gfx950,256,32,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,13,4.6354,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0203,5.09,200.59 -gfx950,256,48,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,3,14,5.2547,_ZN5aiter39bf16gemm_fp32bf16_tn_48x64_splitk_cleanE,0.0212,6.73,195.26 -gfx950,256,64,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,13,5.3561,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.021,8.81,209.54 -gfx950,256,80,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,7,13,5.6419,_ZN5aiter39bf16gemm_fp32bf16_tn_80x64_splitk_cleanE,0.0218,10.45,215.98 -gfx950,256,96,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,3,9,5.7166,_ZN5aiter39bf16gemm_fp32bf16_tn_48x64_splitk_cleanE,0.0163,12.38,230.0 -gfx950,256,112,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,9,5.9183,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0165,13.95,238.43 -gfx950,256,128,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,9,5.97,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0172,15.81,252.48 
-gfx950,256,256,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,6,7.0187,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0126,26.89,324.47 -gfx950,256,1,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,4,9.6772,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0105,1.52,1524.87 -gfx950,256,2,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,4,9.8371,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0107,3.0,1501.19 -gfx950,256,4,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,4,9.8551,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0118,5.98,1500.66 -gfx950,256,8,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,4,9.9035,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0101,11.91,1497.72 -gfx950,256,16,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,4,10.0897,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.01,23.38,1478.7 -gfx950,256,32,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,4,10.2307,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.01,46.12,1475.34 -gfx950,256,48,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,10.7334,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0073,65.94,1422.46 -gfx950,256,64,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,10.6753,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0074,88.4,1446.51 -gfx950,256,128,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,3,11.6504,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0074,162.01,1385.21 -gfx950,256,256,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,8,2,15.3742,_ZN5aiter39bf16gemm_fp32bf16_tn_96x64_splitk_cleanE,0.0041,245.53,1140.28 -gfx950,256,1,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,4,9.0917,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0083,1.3,1298.58 -gfx950,256,1,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,triton,0,0,10.0871,auto,0.0,2.34,2340.31 -gfx950,256,2,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,9.2607,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0056,2.55,1275.95 -gfx950,256,2,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,triton,0,0,9.6663,auto,0.0,4.88,2443.63 -gfx950,256,4,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,4,9.1797,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0091,5.14,1289.36 -gfx950,256,4,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,triton,0,0,10.1637,auto,0.0,9.29,2326.79 -gfx950,256,8,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,4,9.2315,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0092,10.22,1286.39 -gfx950,256,8,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,5,11.4653,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0132,16.46,2067.51 -gfx950,256,16,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,9.3253,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0062,20.24,1281.91 -gfx950,256,16,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,5,11.591,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0136,32.57,2054.71 -gfx950,256,32,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,9.2671,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0065,40.73,1306.98 -gfx950,256,32,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,5,11.7821,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0137,64.08,2040.33 
-gfx950,256,48,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,3,4,9.8145,_ZN5aiter39bf16gemm_fp32bf16_tn_48x64_splitk_cleanE,0.009,57.69,1250.15 -gfx950,256,48,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,3,5,12.1519,_ZN5aiter39bf16gemm_fp32bf16_tn_48x64_splitk_cleanE,0.0136,93.19,1996.61 -gfx950,256,64,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,4,10.1075,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.009,74.69,1229.51 -gfx950,256,64,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,5,12.6689,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0137,119.19,1932.76 -gfx950,256,80,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,7,5,13.1447,_ZN5aiter39bf16gemm_fp32bf16_tn_80x64_splitk_cleanE,0.0137,143.59,1879.78 -gfx950,256,96,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,8,5,13.5787,_ZN5aiter39bf16gemm_fp32bf16_tn_96x64_splitk_cleanE,0.0136,166.8,1836.14 -gfx950,256,112,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,2,17.4374,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0048,151.54,1442.62 -gfx950,256,128,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,2,11.3635,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0035,132.88,1149.12 -gfx950,256,128,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,2,17.3233,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0048,174.33,1465.01 -gfx950,256,256,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,6,1,15.5453,_ZN5aiter37bf16gemm_fp32bf16_tn_64x64_pf3_splitkE,0.0,194.26,921.15 -gfx950,256,1,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,11.5513,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0078,2.55,2554.45 -gfx950,256,2,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,11.4061,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0075,5.17,2588.37 -gfx950,256,4,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,12.0536,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0075,9.79,2451.98 -gfx950,256,8,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,11.96,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0077,19.73,2476.52 -gfx950,256,16,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,12.2044,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0077,38.66,2437.42 -gfx950,256,32,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,12.3457,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0078,76.44,2430.26 -gfx950,256,48,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,3,3,13.0203,_ZN5aiter39bf16gemm_fp32bf16_tn_48x64_splitk_cleanE,0.0076,108.72,2324.0 -gfx950,256,64,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,3,13.4435,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0077,140.4,2269.89 -gfx950,256,80,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,7,3,14.2374,_ZN5aiter39bf16gemm_fp32bf16_tn_80x64_splitk_cleanE,0.0076,165.71,2161.29 -gfx950,256,96,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,8,3,14.4747,_ZN5aiter39bf16gemm_fp32bf16_tn_96x64_splitk_cleanE,0.0076,195.59,2143.55 -gfx950,256,112,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,4,1,19.1801,_ZN5aiter37bf16gemm_fp32bf16_tn_48x64_pf3_splitkE,0.0,172.21,1631.02 -gfx950,256,128,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,4,1,19.3026,_ZN5aiter37bf16gemm_fp32bf16_tn_48x64_pf3_splitkE,0.0,195.56,1633.94 
+gfx950,256,1,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,9,4.6262,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0078,0.16,160.67 +gfx950,256,2,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,9,4.6019,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0156,0.32,162.83 +gfx950,256,4,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,15,4.7202,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0137,0.62,161.29 +gfx950,256,8,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,15,4.7061,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0273,1.25,166.89 +gfx950,256,16,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,15,4.7651,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0273,2.48,174.93 +gfx950,256,32,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,42,15,4.6463,flydsl_gemm2_abf16_wbf16_bf16_t32x64x64_split_k15_block_m_warp2_block_n_warp2_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0237,5.08,200.11 +gfx950,256,48,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,39,15,4.7651,flydsl_gemm2_abf16_wbf16_bf16_t32x64x64_split_k15_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0262,7.43,215.33 +gfx950,256,64,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,188,15,4.7406,flydsl_gemm2_abf16_wbf16_bf16_t32x64x64_split_k15_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0243,9.95,236.74 +gfx950,256,80,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,239,15,4.8623,flydsl_gemm2_abf16_wbf16_bf16_t32x64x64_split_k15_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0253,12.13,250.61 +gfx950,256,96,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,242,15,5.3055,flydsl_gemm2_abf16_wbf16_bf16_t32x64x64_split_k15_block_m_warp2_block_n_warp2_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0247,13.34,247.82 +gfx950,256,112,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,240,15,5.2015,flydsl_gemm2_abf16_wbf16_bf16_t32x64x64_split_k15_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0246,15.88,271.28 +gfx950,256,128,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,205,9,5.6343,flydsl_gemm2_abf16_wbf16_bf16_t16x64x64_split_k9_block_m_warp1_block_n_warp2_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0177,16.75,267.53 +gfx950,256,256,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,370,9,6.5478,flydsl_gemm2_abf16_wbf16_bf16_t32x64x64_split_k9_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.018,28.83,347.81 +gfx950,256,1,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,243,6,8.5347,flydsl_gemm2_abf16_wbf16_bf16_t16x64x160_split_k6_block_m_warp1_block_n_warp2_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_small_m_gfx950,0.0141,1.73,1729.0 +gfx950,256,2,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,243,6,7.904,flydsl_gemm2_abf16_wbf16_bf16_t16x64x160_split_k6_block_m_warp1_block_n_warp2_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_small_m_gfx950,0.0152,3.73,1868.34 
+gfx950,256,4,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,245,15,7.9144,flydsl_gemm2_abf16_wbf16_bf16_t16x64x64_split_k15_block_m_warp1_block_n_warp2_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_small_m_gfx950,0.0266,7.45,1868.63 +gfx950,256,8,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,245,15,8.2005,flydsl_gemm2_abf16_wbf16_bf16_t16x64x64_split_k15_block_m_warp1_block_n_warp2_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_small_m_gfx950,0.0251,14.39,1808.75 +gfx950,256,16,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,268,9,8.614,flydsl_gemm2_abf16_wbf16_bf16_t16x64x64_split_k9_block_m_warp1_block_n_warp2_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_small_m_gfx950,0.0182,27.39,1732.03 +gfx950,256,32,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,4,10.0478,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0099,46.96,1502.2 +gfx950,256,48,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,10.5933,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0073,66.81,1441.27 +gfx950,256,64,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,10.7011,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0074,88.19,1443.02 +gfx950,256,128,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,3,11.7861,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0074,160.14,1369.26 +gfx950,256,256,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,8,2,15.3912,_ZN5aiter39bf16gemm_fp32bf16_tn_96x64_splitk_cleanE,0.0041,245.26,1139.02 +gfx950,256,1,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,139,16,6.0968,flydsl_gemm2_abf16_wbf16_bf16_t16x192x128_split_k16_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_small_m_wpe4_ur16_gfx950,0.0212,1.93,1936.48 +gfx950,256,1,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,388,16,8.3638,flydsl_gemm2_abf16_wbf16_bf16_t16x192x64_split_k16_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_small_m_wpe2_ur8_gfx950,0.0316,2.82,2822.51 +gfx950,256,2,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,138,16,6.0853,flydsl_gemm2_abf16_wbf16_bf16_t16x192x128_split_k16_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_small_m_wpe2_ur8_gfx950,0.0224,3.88,1941.76 +gfx950,256,2,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,139,16,8.2195,flydsl_gemm2_abf16_wbf16_bf16_t16x192x128_split_k16_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_small_m_wpe4_ur16_gfx950,0.0269,5.74,2873.76 +gfx950,256,4,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,140,16,6.1308,flydsl_gemm2_abf16_wbf16_bf16_t16x192x128_split_k16_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_small_m_wpe4_ur8_gfx950,0.0224,7.7,1930.56 +gfx950,256,4,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,138,16,8.2605,flydsl_gemm2_abf16_wbf16_bf16_t16x192x128_split_k16_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_small_m_wpe2_ur8_gfx950,0.0292,11.42,2862.87 +gfx950,256,8,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,771,8,6.726,flydsl_gemm2_abf16_wbf16_bf16_t16x64x256_split_k8_block_m_warp1_block_n_warp2_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_small_m_wpe4_ur8_gfx950,0.0161,14.03,1765.59 
+gfx950,256,8,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,792,8,8.6182,flydsl_gemm2_abf16_wbf16_bf16_t16x64x256_split_k8_block_m_warp1_block_n_warp2_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_small_m_ur16_gfx950,0.0195,21.9,2750.53 +gfx950,256,16,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,750,4,6.9944,flydsl_gemm2_abf16_wbf16_bf16_t16x64x256_split_k4_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_small_m_gfx950,0.0103,26.98,1709.11 +gfx950,256,16,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,798,8,9.0841,flydsl_gemm2_abf16_wbf16_bf16_t16x64x256_split_k8_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_small_m_gfx950,0.0205,41.55,2621.74 +gfx950,256,32,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,145,4,8.0743,flydsl_gemm2_abf16_wbf16_bf16_t32x64x256_split_k4_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0104,46.75,1500.05 +gfx950,256,32,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,164,4,10.4391,flydsl_gemm2_abf16_wbf16_bf16_t32x64x256_split_k4_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0132,72.32,2302.83 +gfx950,256,48,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,353,4,9.5624,flydsl_gemm2_abf16_wbf16_bf16_t64x64x256_split_k4_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0103,59.21,1283.11 +gfx950,256,48,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,3,5,12.1652,_ZN5aiter39bf16gemm_fp32bf16_tn_48x64_splitk_cleanE,0.0136,93.09,1994.43 +gfx950,256,64,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,3,10.045,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0066,75.16,1237.16 +gfx950,256,64,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,5,12.6691,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0136,119.18,1932.73 +gfx950,256,80,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,7,5,12.8662,_ZN5aiter39bf16gemm_fp32bf16_tn_80x64_splitk_cleanE,0.0137,146.7,1920.47 +gfx950,256,96,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,8,5,13.1474,_ZN5aiter39bf16gemm_fp32bf16_tn_96x64_splitk_cleanE,0.0136,172.27,1896.37 +gfx950,256,112,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,2,17.299,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0048,152.75,1454.16 +gfx950,256,128,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,2,11.5282,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0035,130.98,1132.7 +gfx950,256,128,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,2,17.0438,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0048,177.18,1489.04 +gfx950,256,256,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,6,1,15.42,_ZN5aiter37bf16gemm_fp32bf16_tn_64x64_pf3_splitkE,0.0,195.84,928.64 +gfx950,256,1,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,247,9,10.4898,flydsl_gemm2_abf16_wbf16_bf16_t16x64x160_split_k9_block_m_warp1_block_n_warp2_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_small_m_gfx950,0.0182,2.81,2812.94 
+gfx950,256,2,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,247,9,9.6817,flydsl_gemm2_abf16_wbf16_bf16_t16x64x160_split_k9_block_m_warp1_block_n_warp2_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_small_m_gfx950,0.0194,6.09,3049.38 +gfx950,256,4,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,271,9,10.2067,flydsl_gemm2_abf16_wbf16_bf16_t16x64x64_split_k9_block_m_warp1_block_n_warp2_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_small_m_gfx950,0.0201,11.56,2895.67 +gfx950,256,8,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,271,9,11.2498,flydsl_gemm2_abf16_wbf16_bf16_t16x64x64_split_k9_block_m_warp1_block_n_warp2_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_small_m_gfx950,0.0189,20.97,2632.86 +gfx950,256,16,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,271,9,11.9689,flydsl_gemm2_abf16_wbf16_bf16_t16x64x64_split_k9_block_m_warp1_block_n_warp2_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_small_m_gfx950,0.0191,39.42,2485.37 +gfx950,256,32,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,12.2241,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0077,77.2,2454.43 +gfx950,256,48,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,3,3,12.9649,_ZN5aiter39bf16gemm_fp32bf16_tn_48x64_splitk_cleanE,0.0076,109.19,2333.93 +gfx950,256,64,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,3,13.6686,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0077,138.09,2232.5 +gfx950,256,80,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,7,3,13.957,_ZN5aiter39bf16gemm_fp32bf16_tn_80x64_splitk_cleanE,0.0076,169.04,2204.71 +gfx950,256,96,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,8,3,14.3466,_ZN5aiter39bf16gemm_fp32bf16_tn_96x64_splitk_cleanE,0.0076,197.34,2162.69 +gfx950,256,112,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,4,1,19.3332,_ZN5aiter37bf16gemm_fp32bf16_tn_48x64_pf3_splitkE,0.0,170.85,1618.11 +gfx950,256,128,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,4,1,19.5488,_ZN5aiter37bf16gemm_fp32bf16_tn_48x64_pf3_splitkE,0.0,193.1,1613.36 diff --git a/aiter/configs/model_configs/gptoss_bf16_untuned_gemm_smallm.csv b/aiter/configs/model_configs/gptoss_bf16_untuned_gemm_smallm.csv new file mode 100644 index 0000000000..18a82fafff --- /dev/null +++ b/aiter/configs/model_configs/gptoss_bf16_untuned_gemm_smallm.csv @@ -0,0 +1,26 @@ +M,N,K,bias,dtype,outdtype,scaleAB,bpreshuffle +1,128,2880,True,torch.bfloat16,torch.bfloat16,False,False +1,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False +1,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False +1,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False +1,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False +2,128,2880,True,torch.bfloat16,torch.bfloat16,False,False +2,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False +2,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False +2,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False +2,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False +4,128,2880,True,torch.bfloat16,torch.bfloat16,False,False +4,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False +4,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False +4,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False +4,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False +8,128,2880,True,torch.bfloat16,torch.bfloat16,False,False +8,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False 
+8,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False +8,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False +8,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False +16,128,2880,True,torch.bfloat16,torch.bfloat16,False,False +16,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False +16,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False +16,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False +16,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False diff --git a/aiter/configs/model_configs/gptoss_bf16_untuned_gemm_smallm_large_n.csv b/aiter/configs/model_configs/gptoss_bf16_untuned_gemm_smallm_large_n.csv new file mode 100644 index 0000000000..c8c75e1213 --- /dev/null +++ b/aiter/configs/model_configs/gptoss_bf16_untuned_gemm_smallm_large_n.csv @@ -0,0 +1,11 @@ +M,N,K,bias,dtype,outdtype,scaleAB,bpreshuffle +1,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False +1,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False +2,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False +2,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False +4,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False +4,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False +8,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False +8,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False +16,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False +16,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False diff --git a/aiter/ops/flydsl/gemm_kernels.py b/aiter/ops/flydsl/gemm_kernels.py index 887ea082a8..8d86c4d4e1 100644 --- a/aiter/ops/flydsl/gemm_kernels.py +++ b/aiter/ops/flydsl/gemm_kernels.py @@ -60,12 +60,37 @@ SPLIT_K_GLOBAL_SEMAPHORE: dict[SplitKStreamKey, torch.Tensor] = {} SPLIT_K_GLOBAL_SEMAPHORE_STATE: dict[SplitKStreamKey, int] = {} +# Expand the original default HGEMM catalog with the extra cases that proved +# useful during the wider one-off search, instead of maintaining separate +# search-space modes. 
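# Worked example (illustration only, not part of the patch; the helper name is
# hypothetical): what the divisibility filter admits for the gpt-oss shape
# K=2880 with tile_k=64, mirroring _hgemm_split_k_options further down and the
# constants that follow (base split-K set {1, 2, 4, 8, 16}, cap 32, extra
# splits only when 2 <= block_k_loops <= 8).
def _example_hgemm_split_k_options(k: int = 2880, tile_k: int = 64) -> tuple:
    options = set()
    for split_k in range(1, 33):
        if k % split_k or (k // split_k) % tile_k:
            continue  # each split must cover a whole number of K tiles
        if split_k in {1, 2, 4, 8, 16} or 2 <= k // (split_k * tile_k) <= 8:
            options.add(split_k)
    return tuple(sorted(options))

# _example_hgemm_split_k_options() == (1, 9, 15): the split_k=9 and split_k=15
# tile_k=64 kernels selected in the tuned gpt-oss CSV above exist only because
# of this widened block_k_loops window.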
+HGEMM_TILE_N_OPTIONS = (64, 128, 160, 192, 256) +HGEMM_TILE_K_OPTIONS = (64, 96, 128, 160, 256) +HGEMM_TILE_M_OPTIONS = (16, 32, 48, 64, 80, 96, 112, 128, 160, 256) +HGEMM_BASE_SPLIT_K_OPTIONS = (1, 2, 4, 8, 16) +HGEMM_MAX_SPLIT_K = 32 +HGEMM_EXTRA_BLOCK_K_LOOPS_MIN = 2 +HGEMM_EXTRA_BLOCK_K_LOOPS_MAX = 8 KERNEL_CONFIG_VARIANTS = ( + { + "block_m_warps": 1, + "block_n_warps": 2, + "b_to_lds": False, + }, { "block_m_warps": 1, "block_n_warps": 4, "b_to_lds": False, }, + { + "block_m_warps": 2, + "block_n_warps": 2, + "b_to_lds": False, + }, + { + "block_m_warps": 1, + "block_n_warps": 4, + "b_to_lds": True, + }, { "block_m_warps": 2, "block_n_warps": 2, @@ -188,6 +213,33 @@ def _align_up(value: int, alignment: int) -> int: return ((value + alignment - 1) // alignment) * alignment +def _hgemm_tile_m_options(m: Optional[int]) -> tuple[int, ...]: + if m is None: + return HGEMM_TILE_M_OPTIONS + max_tile_m = max(96, _align_up(max(1, m) * 2, 16)) + return tuple(tile_m for tile_m in HGEMM_TILE_M_OPTIONS if tile_m <= max_tile_m) + + +def _hgemm_split_k_options(k: Optional[int], tile_k: int) -> tuple[int, ...]: + if k is None: + return HGEMM_BASE_SPLIT_K_OPTIONS + options = set() + for split_k in range(1, HGEMM_MAX_SPLIT_K + 1): + if k % split_k != 0 or (k // split_k) % tile_k != 0: + continue + if split_k in HGEMM_BASE_SPLIT_K_OPTIONS: + options.add(split_k) + continue + block_k_loops = k // (split_k * tile_k) + if ( + HGEMM_EXTRA_BLOCK_K_LOOPS_MIN + <= block_k_loops + <= HGEMM_EXTRA_BLOCK_K_LOOPS_MAX + ): + options.add(split_k) + return tuple(sorted(options)) + + def _estimate_hgemm_lds_bytes( *, dtype: str, @@ -514,53 +566,54 @@ def get_flydsl_splitk_hgemm_kernels( raise ValueError( "m, n, k must be provided together when requesting shape-aware kernels" ) - tile_ns = [64, 128, 256] - tile_ks = [64, 128] - tile_ms = [16, 32, 48, 64, 96, 128] - split_ks = [1, 2, 4, 8, 16] b_preshuffles = [False, True] - - for tile_m, tile_n, tile_k, split_k, b_preshuffle, variant in product( + tile_ms = _hgemm_tile_m_options(m) + for tile_m, tile_n, tile_k, b_preshuffle, variant in product( tile_ms, - tile_ns, - tile_ks, - split_ks, + HGEMM_TILE_N_OPTIONS, + HGEMM_TILE_K_OPTIONS, b_preshuffles, KERNEL_CONFIG_VARIANTS, ): - config = _normalize_registry_config( - dtype=dtype, - stage=FIXED_STAGE, - tile_m=tile_m, - tile_n=tile_n, - tile_k=tile_k, - split_k=split_k, - block_m_warps=variant["block_m_warps"], - block_n_warps=variant["block_n_warps"], - b_to_lds=variant["b_to_lds"], - b_preshuffle=b_preshuffle, - ) - if config is None: + if n is not None and (n < tile_n or n % tile_n != 0): continue - config["dtype"] = dtype - config["out_dtype"] = out_dtype - config["target_gfx"] = get_gfx() - name = flydsl_kernel_name( - config["stage"], - dtype, - out_dtype, - config["tile_m"], - config["tile_n"], - config["tile_k"], - config["split_k"], - config["block_m_warps"], - config["block_n_warps"], - config["async_copy"], - config["b_to_lds"], - config["b_preshuffle"], - config["c_to_lds"], - ) - kernels[name] = config + split_k_options = _hgemm_split_k_options(k, tile_k) + if not split_k_options: + continue + for split_k in split_k_options: + config = _normalize_registry_config( + dtype=dtype, + stage=FIXED_STAGE, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + split_k=split_k, + block_m_warps=variant["block_m_warps"], + block_n_warps=variant["block_n_warps"], + b_to_lds=variant["b_to_lds"], + b_preshuffle=b_preshuffle, + ) + if config is None: + continue + config["dtype"] = dtype + config["out_dtype"] = out_dtype + 
config["target_gfx"] = get_gfx() + name = flydsl_kernel_name( + config["stage"], + dtype, + out_dtype, + config["tile_m"], + config["tile_n"], + config["tile_k"], + config["split_k"], + config["block_m_warps"], + config["block_n_warps"], + config["async_copy"], + config["b_to_lds"], + config["b_preshuffle"], + config["c_to_lds"], + ) + kernels[name] = config if m is not None and n is not None and k is not None: for config in ( iter_small_m_registry_configs( diff --git a/aiter/ops/flydsl/kernels/small_m_hgemm.py b/aiter/ops/flydsl/kernels/small_m_hgemm.py index ffe6c1bb29..14c06c4adc 100644 --- a/aiter/ops/flydsl/kernels/small_m_hgemm.py +++ b/aiter/ops/flydsl/kernels/small_m_hgemm.py @@ -38,8 +38,6 @@ from __future__ import annotations import functools -import os -from itertools import product import flydsl.compiler as flyc import flydsl.expr as fx @@ -77,57 +75,34 @@ LDG_VEC_SIZE = 8 MAX_LDS_BYTES = 163840 -# Default to a compact search space so offline tuning remains tractable. -# Set `AITER_FLYDSL_SMALL_M_SEARCH_SPACE=exhaustive` to recover the wider -# catalog used for deeper one-off searches. -SMALL_M_SEARCH_SPACE = ( - os.getenv("AITER_FLYDSL_SMALL_M_SEARCH_SPACE", "compact").strip().lower() +# Expand the original small-M catalog with the additional cases that proved +# useful during the deeper exhaustive search, instead of maintaining separate +# compact/exhaustive modes. +SMALL_M_TILE_K_OPTIONS = (32, 64, 96, 128, 160, 192, 256) +SMALL_M_MAX_SPLIT_K = 32 +SMALL_M_TILE_N_OPTIONS = ( + 32, + 64, + 96, + 128, + 160, + 192, + 224, + 256, + 384, + 512, + 768, + 1024, ) -if SMALL_M_SEARCH_SPACE not in {"compact", "exhaustive"}: - raise ValueError( - "Unsupported AITER_FLYDSL_SMALL_M_SEARCH_SPACE=" - f"{SMALL_M_SEARCH_SPACE!r}; expected 'compact' or 'exhaustive'" - ) - -if SMALL_M_SEARCH_SPACE == "exhaustive": - SMALL_M_TILE_K_OPTIONS = [32, 64, 96, 128, 160, 192, 256] - SMALL_M_TILE_N_OPTIONS = [ - 32, - 64, - 96, - 128, - 160, - 192, - 224, - 256, - 384, - 512, - 768, - 1024, - ] - SMALL_M_SPLIT_K_OPTIONS = [1, 2, 4, 8, 16, 32] - SMALL_M_NON_B_TO_LDS_WAVES_PER_EU_OPTIONS = [0, 2, 4] - SMALL_M_B_TO_LDS_WAVES_PER_EU_OPTIONS = [0, 2, 4] - SMALL_M_B_TO_LDS_UNROLL_OPTIONS = [0, 8, 16] - SMALL_M_N_TILE_REPEAT_OPTIONS = [1, 2, 4] - SMALL_M_PERSISTENT_N_TILE_OPTIONS = [2, 4, 8] - SMALL_M_BASE_BLOCK_N_WARPS = (1, 2, 3, 4) - SMALL_M_REPEAT_BLOCK_N_WARPS = (1, 2) - SMALL_M_B_TO_LDS_BLOCK_N_WARPS = (1, 2, 3, 4) - SMALL_M_PERSISTENT_BLOCK_N_WARPS = (2, 3, 4) -else: - SMALL_M_TILE_K_OPTIONS = [64, 128, 256] - SMALL_M_TILE_N_OPTIONS = [64, 128, 192, 256, 512, 1024] - SMALL_M_SPLIT_K_OPTIONS = [1, 2, 4, 8, 16] - SMALL_M_NON_B_TO_LDS_WAVES_PER_EU_OPTIONS = [0] - SMALL_M_B_TO_LDS_WAVES_PER_EU_OPTIONS = [0, 2, 4] - SMALL_M_B_TO_LDS_UNROLL_OPTIONS = [8, 16] - SMALL_M_N_TILE_REPEAT_OPTIONS = [1, 2] - SMALL_M_PERSISTENT_N_TILE_OPTIONS = [2, 4] - SMALL_M_BASE_BLOCK_N_WARPS = (1, 2, 4) - SMALL_M_REPEAT_BLOCK_N_WARPS = (1, 2) - SMALL_M_B_TO_LDS_BLOCK_N_WARPS = (1, 2, 4) - SMALL_M_PERSISTENT_BLOCK_N_WARPS = (2, 4) +SMALL_M_NON_B_TO_LDS_WAVES_PER_EU_OPTIONS = (0, 2, 4) +SMALL_M_B_TO_LDS_WAVES_PER_EU_OPTIONS = (0, 2, 4) +SMALL_M_B_TO_LDS_UNROLL_OPTIONS = (0, 8, 16) +SMALL_M_N_TILE_REPEAT_OPTIONS = (1, 2, 4) +SMALL_M_PERSISTENT_N_TILE_OPTIONS = (2, 4, 8) +SMALL_M_BASE_BLOCK_N_WARPS = (1, 2, 3, 4) +SMALL_M_REPEAT_BLOCK_N_WARPS = (1, 2) +SMALL_M_B_TO_LDS_BLOCK_N_WARPS = (1, 2, 3, 4) +SMALL_M_PERSISTENT_BLOCK_N_WARPS = (2, 3, 4) def _ceil_div(x: int, y: int) -> int: @@ -138,6 +113,25 @@ def _align_up(x: int, y: int) 
-> int: return ((x + y - 1) // y) * y +def _small_m_tile_k_options(k: int) -> tuple[int, ...]: + return tuple( + tile_k + for tile_k in SMALL_M_TILE_K_OPTIONS + if any( + k % split_k == 0 and (k // split_k) % tile_k == 0 + for split_k in range(1, SMALL_M_MAX_SPLIT_K + 1) + ) + ) + + +def _small_m_split_k_options(k: int, tile_k: int) -> tuple[int, ...]: + return tuple( + split_k + for split_k in range(1, SMALL_M_MAX_SPLIT_K + 1) + if k % split_k == 0 and (k // split_k) % tile_k == 0 + ) + + def small_m_kernel_name( dtype: str, *, @@ -319,55 +313,56 @@ def iter_small_m_registry_configs( return seen_configs = set() - for tile_n, tile_k, split_k in product( - SMALL_M_TILE_N_OPTIONS, - SMALL_M_TILE_K_OPTIONS, - SMALL_M_SPLIT_K_OPTIONS, - ): - for variant in _small_m_registry_variants(): - config = { - "kernel_family": "small_m", - "stage": STAGES, - "tile_m": TILE_M, - "tile_n": tile_n, - "tile_k": tile_k, - "split_k": split_k, - "block_m_warps": BLOCK_M_WARPS, - "block_n_warps": variant["block_n_warps"], - "n_tile_repeat": variant["n_tile_repeat"], - "persistent_n_tiles": variant["persistent_n_tiles"], - "waves_per_eu": variant["waves_per_eu"], - "b_to_lds_unroll": variant["b_to_lds_unroll"], - "async_copy": True, - "b_to_lds": variant["b_to_lds"], - "b_preshuffle": False, - "c_to_lds": False, - "dtype": dtype, - "out_dtype": out_dtype, - "target_gfx": get_gfx(), - } - try: - _validate_small_m_registry_config( - m, - n, - k, - tile_n=config["tile_n"], - tile_k=config["tile_k"], - split_k=config["split_k"], - block_n_warps=config["block_n_warps"], - n_tile_repeat=config["n_tile_repeat"], - persistent_n_tiles=config["persistent_n_tiles"], - waves_per_eu=config["waves_per_eu"], - b_to_lds_unroll=config["b_to_lds_unroll"], - b_to_lds=config["b_to_lds"], - ) - except ValueError: - continue - config_key = tuple(sorted(config.items())) - if config_key in seen_configs: + for tile_n in SMALL_M_TILE_N_OPTIONS: + for tile_k in _small_m_tile_k_options(k): + split_k_options = _small_m_split_k_options(k, tile_k) + if not split_k_options: continue - seen_configs.add(config_key) - yield config + for split_k in split_k_options: + for variant in _small_m_registry_variants(): + config = { + "kernel_family": "small_m", + "stage": STAGES, + "tile_m": TILE_M, + "tile_n": tile_n, + "tile_k": tile_k, + "split_k": split_k, + "block_m_warps": BLOCK_M_WARPS, + "block_n_warps": variant["block_n_warps"], + "n_tile_repeat": variant["n_tile_repeat"], + "persistent_n_tiles": variant["persistent_n_tiles"], + "waves_per_eu": variant["waves_per_eu"], + "b_to_lds_unroll": variant["b_to_lds_unroll"], + "async_copy": True, + "b_to_lds": variant["b_to_lds"], + "b_preshuffle": False, + "c_to_lds": False, + "dtype": dtype, + "out_dtype": out_dtype, + "target_gfx": get_gfx(), + } + try: + _validate_small_m_registry_config( + m, + n, + k, + tile_n=config["tile_n"], + tile_k=config["tile_k"], + split_k=config["split_k"], + block_n_warps=config["block_n_warps"], + n_tile_repeat=config["n_tile_repeat"], + persistent_n_tiles=config["persistent_n_tiles"], + waves_per_eu=config["waves_per_eu"], + b_to_lds_unroll=config["b_to_lds_unroll"], + b_to_lds=config["b_to_lds"], + ) + except ValueError: + continue + config_key = tuple(sorted(config.items())) + if config_key in seen_configs: + continue + seen_configs.add(config_key) + yield config @functools.lru_cache(maxsize=512) From ca27c93033b1962066482ee09bfaf8714a8e1007 Mon Sep 17 00:00:00 2001 From: zhuyuhua-v Date: Sun, 19 Apr 2026 21:37:46 -0500 Subject: [PATCH 03/11] updata small_m_hgmm --- 
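The small-M compile cache is keyed by the full kernel configuration, and the shape-aware catalog added in the previous patch can issue well over 512 distinct compile requests in a single tuning run, so maxsize=512 can thrash. A minimal self-contained sketch of the failure mode this bump avoids (fake_compile and the config count are illustrative stand-ins, not the real compile_small_m_hgemm_kernel):

    from functools import lru_cache
    from itertools import product

    calls = 0

    @lru_cache(maxsize=512)
    def fake_compile(cfg):
        global calls
        calls += 1  # each cache miss stands in for an expensive JIT compile
        return f"bin[{cfg}]"

    # 1024 distinct configs scanned twice: a 512-entry LRU evicts every key
    # before its second use, so both passes compile from scratch (prints 2048);
    # with maxsize=1024 the second pass costs zero compiles (prints 1024).
    for _ in range(2):
        for cfg in product(range(32), range(32)):
            fake_compile(cfg)
    print(calls)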
aiter/ops/flydsl/kernels/small_m_hgemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiter/ops/flydsl/kernels/small_m_hgemm.py b/aiter/ops/flydsl/kernels/small_m_hgemm.py index 14c06c4adc..a7cb3354d0 100644 --- a/aiter/ops/flydsl/kernels/small_m_hgemm.py +++ b/aiter/ops/flydsl/kernels/small_m_hgemm.py @@ -365,7 +365,7 @@ def iter_small_m_registry_configs( yield config -@functools.lru_cache(maxsize=512) +@functools.lru_cache(maxsize=1024) def compile_small_m_hgemm_kernel( dtype: str, n: int, From 96579aaba0ad95f93947ea3efcf9cb52af6cbb58 Mon Sep 17 00:00:00 2001 From: zhuyuhua-v Date: Sun, 19 Apr 2026 21:52:16 -0500 Subject: [PATCH 04/11] clear code --- aiter/aot/flydsl/gemm.py | 2 +- .../gptoss_bf16_untuned_gemm_smallm.csv | 26 ------------------- ...ptoss_bf16_untuned_gemm_smallm_large_n.csv | 11 -------- gradlib/gradlib/GemmTuner.py | 2 +- 4 files changed, 2 insertions(+), 39 deletions(-) delete mode 100644 aiter/configs/model_configs/gptoss_bf16_untuned_gemm_smallm.csv delete mode 100644 aiter/configs/model_configs/gptoss_bf16_untuned_gemm_smallm_large_n.csv diff --git a/aiter/aot/flydsl/gemm.py b/aiter/aot/flydsl/gemm.py index bb0c0b93ff..a8c5902a09 100644 --- a/aiter/aot/flydsl/gemm.py +++ b/aiter/aot/flydsl/gemm.py @@ -123,7 +123,7 @@ def parse_csv(csv_path: str): for row in reader: kernel_name = row.get("kernelName", "").strip() libtype = row.get("libtype", "").strip() - if libtype != "flydsl": + if libtype != "flydsl" or not kernel_name.startswith("flydsl_") continue m = int(row["M"]) diff --git a/aiter/configs/model_configs/gptoss_bf16_untuned_gemm_smallm.csv b/aiter/configs/model_configs/gptoss_bf16_untuned_gemm_smallm.csv deleted file mode 100644 index 18a82fafff..0000000000 --- a/aiter/configs/model_configs/gptoss_bf16_untuned_gemm_smallm.csv +++ /dev/null @@ -1,26 +0,0 @@ -M,N,K,bias,dtype,outdtype,scaleAB,bpreshuffle -1,128,2880,True,torch.bfloat16,torch.bfloat16,False,False -1,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False -1,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False -1,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False -1,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False -2,128,2880,True,torch.bfloat16,torch.bfloat16,False,False -2,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False -2,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False -2,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False -2,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False -4,128,2880,True,torch.bfloat16,torch.bfloat16,False,False -4,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False -4,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False -4,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False -4,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False -8,128,2880,True,torch.bfloat16,torch.bfloat16,False,False -8,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False -8,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False -8,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False -8,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False -16,128,2880,True,torch.bfloat16,torch.bfloat16,False,False -16,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False -16,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False -16,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False -16,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False diff --git a/aiter/configs/model_configs/gptoss_bf16_untuned_gemm_smallm_large_n.csv b/aiter/configs/model_configs/gptoss_bf16_untuned_gemm_smallm_large_n.csv deleted 
file mode 100644 index c8c75e1213..0000000000 --- a/aiter/configs/model_configs/gptoss_bf16_untuned_gemm_smallm_large_n.csv +++ /dev/null @@ -1,11 +0,0 @@ -M,N,K,bias,dtype,outdtype,scaleAB,bpreshuffle -1,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False -1,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False -2,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False -2,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False -4,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False -4,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False -8,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False -8,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False -16,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False -16,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False diff --git a/gradlib/gradlib/GemmTuner.py b/gradlib/gradlib/GemmTuner.py index 2a4db4d332..e99c63adec 100644 --- a/gradlib/gradlib/GemmTuner.py +++ b/gradlib/gradlib/GemmTuner.py @@ -135,7 +135,7 @@ def run_flydsl_gemm_bf16(input, weight, bias=None, otype=dtypes.bf16, config=Non return out -@lru_cache(maxsize=1024) +@lru_cache(maxsize=1) def get_flydsl_bf16_catalog(m: int, n: int, k: int): if get_flydsl_splitk_hgemm_kernels is None: return [] From f5bd6d2bf4309a3356cb970900b7f54f96106f16 Mon Sep 17 00:00:00 2001 From: zhuyuhua-v Date: Sun, 19 Apr 2026 22:14:06 -0500 Subject: [PATCH 05/11] code format --- aiter/ops/flydsl/gemm_kernels.py | 2 +- aiter/ops/flydsl/kernels/moe_gemm_2stage.py | 30 ++++++--------------- aiter/ops/flydsl/kernels/small_m_hgemm.py | 1 - gradlib/gradlib/GemmTuner.py | 2 +- 4 files changed, 10 insertions(+), 25 deletions(-) diff --git a/aiter/ops/flydsl/gemm_kernels.py b/aiter/ops/flydsl/gemm_kernels.py index 8d86c4d4e1..777525fdec 100644 --- a/aiter/ops/flydsl/gemm_kernels.py +++ b/aiter/ops/flydsl/gemm_kernels.py @@ -20,7 +20,7 @@ from ..shuffle import shuffle_weight from .kernels.hgemm_dispatch import compile_flydsl_hgemm_kernel -from .kernels.small_m_hgemm import SMALL_M_KERNEL_MAX, iter_small_m_registry_configs +from .kernels.small_m_hgemm import iter_small_m_registry_configs from .kernels.tensor_shim import _run_compiled from .utils import get_shared_memory_per_block, is_flydsl_available diff --git a/aiter/ops/flydsl/kernels/moe_gemm_2stage.py b/aiter/ops/flydsl/kernels/moe_gemm_2stage.py index f1630010bb..8b5308380d 100644 --- a/aiter/ops/flydsl/kernels/moe_gemm_2stage.py +++ b/aiter/ops/flydsl/kernels/moe_gemm_2stage.py @@ -38,7 +38,7 @@ def bf16_global_atomics_arch_description() -> str: from flydsl._mlir import ir -from flydsl._mlir.dialects import llvm, scf, memref +from flydsl._mlir.dialects import llvm, scf from flydsl.expr.typing import T @@ -219,14 +219,9 @@ def out_mlir(): "stage1 cshuffle epilog currently supports only f16 output (out_dtype='f16')" ) - epilog_tag = "cshuffle" if use_cshuffle_epilog else "direct" - # IMPORTANT: module name participates in FlyDSL's compile cache key. - # Keep an explicit ABI tag so signature changes can't accidentally reuse an old binary. - module_name = ( - f"mfma_moe1_{in_dtype}_{out_dtype}_{epilog_tag}" - f"_t{tile_m}x{tile_n}x{tile_k}" - f"_abi3" # also mask sentinel token ids on loads (X/scale_x) to avoid illegal address faults - ).replace("-", "_") + # IMPORTANT: FlyDSL compile cache keys should include an explicit ABI tag so + # signature changes cannot reuse an old binary. 
Intended tag pattern: + # mfma_moe1_{in_dtype}_{out_dtype}_{cshuffle|direct}_t{tile_m}x{tile_n}x{tile_k}_abi3 # ── LDS sizing (pure Python; no MLIR Context needed) ───────────────────── # Reuse the same LDS bytes for both: @@ -1076,7 +1071,7 @@ def hot_loop_scheduler(): # Epilogue hoists to keep IR + Python build time small: col_i32_list = [] for ni in range_constexpr(num_acc_n): - col_i32_list.append(arith.index_cast(i32, col_g_list[ni])) + col_i32_list.append(arith.index_cast(T.i32, col_g_list[ni])) inter_i32_local = inter_i32_v @@ -1525,18 +1520,9 @@ def out_elem(): ty = T.f32 if out_is_f32 else (T.bf16 if out_is_bf16 else T.f16) return ty() if callable(ty) else ty - epilog_tag = "cshuffle" - # IMPORTANT: include tiling in the module name to avoid accidentally reusing a compiled - # binary for a different (tile_m, tile_n, tile_k) configuration. - # See stage1 note: include ABI tag to prevent binary reuse across signature changes. - # IMPORTANT: module name participates in FlyDSL's compile cache key. - # Dynamic-shape variant: safe to reuse across (tokens/sorted_size/size_expert_ids) at runtime. - # Keep a distinct ABI tag so the compile cache never mixes with historical signatures. - module_name = ( - f"mfma_moe2_{in_dtype}_{out_s}_{epilog_tag}" - f"_t{tile_m}x{tile_n}x{tile_k}" - f"_abi2" # mask sentinel token ids on loads/stores to avoid illegal address faults - ).replace("-", "_") + # IMPORTANT: include tiling in any future module/cache tag to avoid reusing a compiled + # binary for a different (tile_m, tile_n, tile_k). Stage2 dynamic-shape tag pattern: + # mfma_moe2_{in_dtype}_{out_s}_cshuffle_t{tile_m}x{tile_n}x{tile_k}_abi2 # ── CShuffle epilogue e_vec (pure Python; must be computed before @flyc.kernel # because the AST rewriter intercepts `if` statements inside kernel bodies and diff --git a/aiter/ops/flydsl/kernels/small_m_hgemm.py b/aiter/ops/flydsl/kernels/small_m_hgemm.py index a7cb3354d0..e0d03f5618 100644 --- a/aiter/ops/flydsl/kernels/small_m_hgemm.py +++ b/aiter/ops/flydsl/kernels/small_m_hgemm.py @@ -467,7 +467,6 @@ def compile_small_m_hgemm_kernel( BLOCK_NK_SIZE = BLOCK_N * BLOCK_K BLOCK_MN_SIZE = BLOCK_M * BLOCK_N LDG_A_X_THREADS = BLOCK_K // LDG_VEC_SIZE - LDG_B_X_THREADS = BLOCK_K // LDG_VEC_SIZE LDG_C_X_THREADS = BLOCK_N // LDG_VEC_SIZE assert BLOCK_MK_SIZE % LDG_VEC_SIZE == 0 assert BLOCK_NK_SIZE % LDG_VEC_SIZE == 0 diff --git a/gradlib/gradlib/GemmTuner.py b/gradlib/gradlib/GemmTuner.py index e99c63adec..bdbcac5681 100644 --- a/gradlib/gradlib/GemmTuner.py +++ b/gradlib/gradlib/GemmTuner.py @@ -127,7 +127,7 @@ def run_flydsl_gemm_bf16(input, weight, bias=None, otype=dtypes.bf16, config=Non auto_shuffle_b=False, c_to_lds=config.get("c_to_lds", False), ) - + if bias is not None and fused_bias is None: out = out.to(bias.dtype) + bias if otype is not None and out.dtype != otype: From 68a4258bcd93dd81c697e266ed3b21732fd437cf Mon Sep 17 00:00:00 2001 From: zhuyuhua-v Date: Sun, 19 Apr 2026 22:24:12 -0500 Subject: [PATCH 06/11] fix code format --- aiter/aot/flydsl/gemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiter/aot/flydsl/gemm.py b/aiter/aot/flydsl/gemm.py index a8c5902a09..12c8b807fa 100644 --- a/aiter/aot/flydsl/gemm.py +++ b/aiter/aot/flydsl/gemm.py @@ -123,7 +123,7 @@ def parse_csv(csv_path: str): for row in reader: kernel_name = row.get("kernelName", "").strip() libtype = row.get("libtype", "").strip() - if libtype != "flydsl" or not kernel_name.startswith("flydsl_") + if libtype != "flydsl" or not kernel_name.startswith("flydsl_"): 
continue m = int(row["M"]) From ee63ac29fbe386a545fc7254744089857b92bca7 Mon Sep 17 00:00:00 2001 From: zhuyuhua-v Date: Sun, 19 Apr 2026 23:31:59 -0500 Subject: [PATCH 07/11] remove moe_gemm_2stage.py change --- aiter/ops/flydsl/kernels/moe_gemm_2stage.py | 30 +++++++++++++++------ 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/aiter/ops/flydsl/kernels/moe_gemm_2stage.py b/aiter/ops/flydsl/kernels/moe_gemm_2stage.py index 8b5308380d..f1630010bb 100644 --- a/aiter/ops/flydsl/kernels/moe_gemm_2stage.py +++ b/aiter/ops/flydsl/kernels/moe_gemm_2stage.py @@ -38,7 +38,7 @@ def bf16_global_atomics_arch_description() -> str: from flydsl._mlir import ir -from flydsl._mlir.dialects import llvm, scf +from flydsl._mlir.dialects import llvm, scf, memref from flydsl.expr.typing import T @@ -219,9 +219,14 @@ def out_mlir(): "stage1 cshuffle epilog currently supports only f16 output (out_dtype='f16')" ) - # IMPORTANT: FlyDSL compile cache keys should include an explicit ABI tag so - # signature changes cannot reuse an old binary. Intended tag pattern: - # mfma_moe1_{in_dtype}_{out_dtype}_{cshuffle|direct}_t{tile_m}x{tile_n}x{tile_k}_abi3 + epilog_tag = "cshuffle" if use_cshuffle_epilog else "direct" + # IMPORTANT: module name participates in FlyDSL's compile cache key. + # Keep an explicit ABI tag so signature changes can't accidentally reuse an old binary. + module_name = ( + f"mfma_moe1_{in_dtype}_{out_dtype}_{epilog_tag}" + f"_t{tile_m}x{tile_n}x{tile_k}" + f"_abi3" # also mask sentinel token ids on loads (X/scale_x) to avoid illegal address faults + ).replace("-", "_") # ── LDS sizing (pure Python; no MLIR Context needed) ───────────────────── # Reuse the same LDS bytes for both: @@ -1071,7 +1076,7 @@ def hot_loop_scheduler(): # Epilogue hoists to keep IR + Python build time small: col_i32_list = [] for ni in range_constexpr(num_acc_n): - col_i32_list.append(arith.index_cast(T.i32, col_g_list[ni])) + col_i32_list.append(arith.index_cast(i32, col_g_list[ni])) inter_i32_local = inter_i32_v @@ -1520,9 +1525,18 @@ def out_elem(): ty = T.f32 if out_is_f32 else (T.bf16 if out_is_bf16 else T.f16) return ty() if callable(ty) else ty - # IMPORTANT: include tiling in any future module/cache tag to avoid reusing a compiled - # binary for a different (tile_m, tile_n, tile_k). Stage2 dynamic-shape tag pattern: - # mfma_moe2_{in_dtype}_{out_s}_cshuffle_t{tile_m}x{tile_n}x{tile_k}_abi2 + epilog_tag = "cshuffle" + # IMPORTANT: include tiling in the module name to avoid accidentally reusing a compiled + # binary for a different (tile_m, tile_n, tile_k) configuration. + # See stage1 note: include ABI tag to prevent binary reuse across signature changes. + # IMPORTANT: module name participates in FlyDSL's compile cache key. + # Dynamic-shape variant: safe to reuse across (tokens/sorted_size/size_expert_ids) at runtime. + # Keep a distinct ABI tag so the compile cache never mixes with historical signatures. 
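The cache-key rule these comments keep restating can be condensed into a short sketch. make_module_name here is a hypothetical helper for illustration only, not a FlyDSL API; the restored module_name expression just below is the authoritative version.

def make_module_name(prefix: str, in_dtype: str, out_dtype: str,
                     epilog_tag: str, tile_m: int, tile_n: int,
                     tile_k: int, abi_tag: str) -> str:
    # Every field that changes codegen or the kernel ABI must appear in the
    # name; a field left out is a field the compile cache cannot see.
    name = (
        f"{prefix}_{in_dtype}_{out_dtype}_{epilog_tag}"
        f"_t{tile_m}x{tile_n}x{tile_k}_{abi_tag}"
    )
    return name.replace("-", "_")  # keep the name a legal identifier

# make_module_name("mfma_moe2", "f8", "f16", "cshuffle", 32, 128, 128, "abi2")
# -> "mfma_moe2_f8_f16_cshuffle_t32x128x128_abi2"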
+ module_name = ( + f"mfma_moe2_{in_dtype}_{out_s}_{epilog_tag}" + f"_t{tile_m}x{tile_n}x{tile_k}" + f"_abi2" # mask sentinel token ids on loads/stores to avoid illegal address faults + ).replace("-", "_") # ── CShuffle epilogue e_vec (pure Python; must be computed before @flyc.kernel # because the AST rewriter intercepts `if` statements inside kernel bodies and From d91f0a5e9131b23ef4c1e63c622d6f38fc4a8eb9 Mon Sep 17 00:00:00 2001 From: zhuyuhua-v Date: Mon, 20 Apr 2026 03:29:58 -0500 Subject: [PATCH 08/11] refactor code --- aiter/aot/flydsl/gemm.py | 6 +- aiter/ops/flydsl/kernels/small_m_hgemm.py | 179 +++++++++++----------- aiter/ops/flydsl/kernels/splitk_hgemm.py | 163 ++++++++++---------- 3 files changed, 171 insertions(+), 177 deletions(-) diff --git a/aiter/aot/flydsl/gemm.py b/aiter/aot/flydsl/gemm.py index 12c8b807fa..6163d6c77a 100644 --- a/aiter/aot/flydsl/gemm.py +++ b/aiter/aot/flydsl/gemm.py @@ -260,9 +260,9 @@ def _compile_hgemm_to_cache( c_to_lds=c_to_lds, has_bias=has_bias, ) - _compile_executable_to_cache( - exe, out, a, b, bias if has_bias else None, m, counter, 0, stream - ) + # FlyDSL JIT does not accept None for tensor slots; pass a real buffer when + # bias fusion is disabled (matches runtime launcher dummy tensor behavior). + _compile_executable_to_cache(exe, out, a, b, bias, m, counter, 0, stream) def _compile_preshuffle_to_cache( diff --git a/aiter/ops/flydsl/kernels/small_m_hgemm.py b/aiter/ops/flydsl/kernels/small_m_hgemm.py index e0d03f5618..fda67e801f 100644 --- a/aiter/ops/flydsl/kernels/small_m_hgemm.py +++ b/aiter/ops/flydsl/kernels/small_m_hgemm.py @@ -95,8 +95,10 @@ 1024, ) SMALL_M_NON_B_TO_LDS_WAVES_PER_EU_OPTIONS = (0, 2, 4) -SMALL_M_B_TO_LDS_WAVES_PER_EU_OPTIONS = (0, 2, 4) -SMALL_M_B_TO_LDS_UNROLL_OPTIONS = (0, 8, 16) +# Avoid catalog entries that normalize to the same kernel (hint 0 becomes 2 / 8 +# for wide-N B_TO_LDS in compile_small_m_hgemm_kernel).
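The "hint 0 becomes 2 / 8" remark deserves spelling out. A sketch of the normalization this comment assumes, matching the defaulting that the later compile_small_m_hgemm_kernel hunks show (both helper names here are illustrative, not functions in the tree):

def effective_waves_per_eu(hint: int, wide_n_b_to_lds: bool) -> int:
    # 0 means "pick a default"; for wide-N B_TO_LDS the default is 2, so the
    # catalog rows (0, ...) and (2, ...) compile to the very same kernel.
    return int(hint) if hint > 0 else (2 if wide_n_b_to_lds else 0)

def effective_b_to_lds_unroll(hint: int) -> int:
    return int(hint) if hint > 0 else 8  # unroll hint 0 defaults to 8

assert effective_waves_per_eu(0, True) == effective_waves_per_eu(2, True)
assert effective_b_to_lds_unroll(0) == effective_b_to_lds_unroll(8)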
+SMALL_M_B_TO_LDS_WAVES_PER_EU_OPTIONS = (2, 4) +SMALL_M_B_TO_LDS_UNROLL_OPTIONS = (8, 16) SMALL_M_N_TILE_REPEAT_OPTIONS = (1, 2, 4) SMALL_M_PERSISTENT_N_TILE_OPTIONS = (2, 4, 8) SMALL_M_BASE_BLOCK_N_WARPS = (1, 2, 3, 4) @@ -538,8 +540,7 @@ def small_m_hgemm_kernel( A_ = GTensor(A, dtype=dtype_, shape=(-1, k)) B_ = GTensor(B, dtype=dtype_, shape=(n, k)) C_ = GTensor(C, dtype=dtype_, shape=(-1, n)) - if HAS_BIAS: - BIAS_ = GTensor(BIAS, dtype=dtype_, shape=(n,)) + BIAS_ = GTensor(BIAS, dtype=dtype_, shape=(n,)) base_ptr = allocator.get_base() smem_a_ptr = SmemPtr( @@ -561,8 +562,7 @@ def small_m_hgemm_kernel( base_ptr, smem_a_offset, dtype_, shape=(BLOCK_M * BLOCK_N,) ) cs_ = STensor(smem_c_ptr, dtype_, shape=(BLOCK_M, BLOCK_N)) - if IS_SPLIT_K: - COUNTER_ = GTensor(COUNTER, dtype=T.i32, shape=(-1,)) + COUNTER_ = GTensor(COUNTER, dtype=T.i32, shape=(-1,)) tid = fx.Int32(fx.thread_idx.x) wid = tid // WARP_SIZE @@ -613,7 +613,7 @@ def small_m_hgemm_kernel( zero_b_frag = vector.broadcast(B_FRAG_T, c_zero_d) c_frags = [acc_init] * (C_FRAGS_LEN * N_TILE_REPEAT) - def zero_c_tile(tile_n_offset): + def zero_c_tile(bias_g, tile_n_offset): zero_vec = vector.broadcast(T.vec(LDG_VEC_SIZE, dtype_), c_zero_d) for i in range_constexpr(LDG_REG_C_COUNT): global_tid = BLOCK_THREADS * i + tid @@ -622,7 +622,7 @@ def zero_c_tile(tile_n_offset): row_idx = m_offset + fx.Index(m_local_idx) init_vec = zero_vec if HAS_BIAS: - init_vec = BIAS_.vec_load( + init_vec = bias_g.vec_load( (tile_n_offset + n_local_idx,), LDG_VEC_SIZE ) cond_boundary = arith.cmpi( @@ -635,12 +635,12 @@ def zero_c_tile(tile_n_offset): ) scf.YieldOp([]) - def init_split_k_counter(tile_counter_idx): + def init_split_k_counter(counter_tensor, tile_counter_idx): is_t0_cond = arith.cmpi(arith.CmpIPredicate.eq, fx.Index(tid), fx.Index(0)) is_t0_cond_if = scf.IfOp(is_t0_cond, results_=[], has_else=False) with ir.InsertionPoint(is_t0_cond_if.then_block): counter_base_ptr = fly.extract_aligned_pointer_as_index( - _ptr_type, fly_values(COUNTER)[0] + _ptr_type, fly_values(counter_tensor)[0] ) counter_base_ptr = llvm.PtrToIntOp(_i64_type, counter_base_ptr).result counter_byte_offset = arith.index_cast( @@ -674,7 +674,7 @@ def init_split_k_counter(tile_counter_idx): rocdl.s_waitcnt(0) scf.YieldOp([]) - def cleanup_stale_counters_once(): + def cleanup_stale_counters_once(counter_g): clean_cond = arith.cmpi( arith.CmpIPredicate.ult, fx.Index(tid), @@ -685,10 +685,10 @@ def cleanup_stale_counters_once(): clean_counter_idx = fx.Int32( ((signal_state + 2) % 3) * SPLIT_K_COUNTER_MAX_LEN ) + fx.Index(tid) - COUNTER_[fx.Index(clean_counter_idx)] = arith.constant(0, type=T.i32) + counter_g[fx.Index(clean_counter_idx)] = arith.constant(0, type=T.i32) scf.YieldOp([]) - def split_k_barrier(tile_counter_idx): + def split_k_barrier(counter_tensor, tile_counter_idx): init_cur = arith.constant(0, type=T.i32) w = scf.WhileOp([T.i32], [init_cur]) before = ir.Block.create_at_start(w.before, [T.i32]) @@ -701,7 +701,7 @@ def split_k_barrier(tile_counter_idx): scf.ConditionOp(need_wait, [cur]) with ir.InsertionPoint(after): counter_base_ptr = fly.extract_aligned_pointer_as_index( - _ptr_type, fly_values(COUNTER)[0] + _ptr_type, fly_values(counter_tensor)[0] ) counter_base_ptr = llvm.PtrToIntOp(_i64_type, counter_base_ptr).result counter_byte_offset = arith.index_cast( @@ -836,50 +836,6 @@ def ldg_sts_a_async(k_offset, lds_stage): scf.YieldOp([]) scf.YieldOp([]) - def ldg_sts_b_async(k_offset, lds_stage, tile_n_offset): - for i in range_constexpr(LDG_REG_B_COUNT_AS): - 
global_tid = BLOCK_THREADS * i + tid - n_local_idx = global_tid // LDG_B_X_THREADS_AS - k_local_idx = global_tid % LDG_B_X_THREADS_AS * LDG_ASYNC_VEC_SIZE - col_in_bytes = k_local_idx * DTYPE_BYTES - col_in_bytes = swizzle_xor16(n_local_idx, col_in_bytes, k_blocks16) - col_idx = fx.Index(k_offset + col_in_bytes // DTYPE_BYTES) - slot_valid = arith.cmpi( - arith.CmpIPredicate.ult, - fx.Index(global_tid), - fx.Index(LDG_B_TOTAL_VECS_AS), - ) - slot_if = scf.IfOp(slot_valid, results_=[], has_else=False) - with ir.InsertionPoint(slot_if.then_block): - global_offset = B_.linear_offset( - (tile_n_offset + fx.Index(n_local_idx), col_idx) - ) - global_offset = arith.index_cast(T.i32, global_offset * DTYPE_BYTES) - lds_offset = ( - bs_.linear_offset( - (fx.Index(lds_stage), n_local_idx, k_local_idx) - ) - * DTYPE_BYTES - ) - lds_ptr_type = ir.Type.parse("!llvm.ptr<3>") - lds_addr = ( - memref.extract_aligned_pointer_as_index(bs_.memptr) + lds_offset - ) - lds_addr_ = rocdl.readfirstlane( - T.i64, arith.index_cast(T.i64, lds_addr) - ) - lds_ptr = llvm.inttoptr(lds_ptr_type, lds_addr_) - rocdl.raw_ptr_buffer_load_lds( - B_.rsrc, - lds_ptr, - arith.constant(DMA_BYTES, type=T.i32), - global_offset, - arith.constant(0, type=T.i32), - arith.constant(0, type=T.i32), - arith.constant(1, type=T.i32), - ) - scf.YieldOp([]) - def lds_matrix_a(lds_stage): s = fx.Index(lds_stage) a_frags = [0] * A_FRAGS_LEN @@ -899,25 +855,6 @@ def lds_matrix_a(lds_stage): a_frags[kk * WARP_M_STEPS + ii] = vec return a_frags - def lds_matrix_b(lds_stage): - s = fx.Index(lds_stage) - b_frags = [0] * B_FRAGS_LEN - for ii in range_constexpr(WARP_N_STEPS): - warp_atom_n_idx = warp_n_idx + ii * WARP_ATOM_N - for kk in range_constexpr(WARP_K_STEPS): - warp_atom_k_idx = kk * WARP_ATOM_K - row = warp_atom_n_idx + ldmatrix_b_n_idx - col_in_bytes = ( - warp_atom_k_idx + ldmatrix_b_k_vec_idx - ) * DTYPE_BYTES - col_in_bytes = swizzle_xor16(row, col_in_bytes, k_blocks16) - vec = bs_.vec_load( - (s, row, col_in_bytes // DTYPE_BYTES), - WMMA_B_FRAG_VALUES * MFMA_PER_WARP_K, - ) - b_frags[kk * WARP_N_STEPS + ii] = vec - return b_frags - def ldg_matrix_b(k_offset, tile_n_offset): vecs = [] for kk in range_constexpr(WARP_K_STEPS): @@ -1015,7 +952,7 @@ def store_split_k_tile(tile_n_offset): ) scf.YieldOp([]) - def store_c_tile(tile_n_offset): + def store_c_tile(bias_g, tile_n_offset): for i in range_constexpr(LDG_REG_C_COUNT): global_tid = BLOCK_THREADS * i + tid m_local_idx = fx.Index(global_tid // LDG_C_X_THREADS) @@ -1028,7 +965,7 @@ def store_c_tile(tile_n_offset): with ir.InsertionPoint(cond_boundary_if.then_block): vec = cs_.vec_load((m_local_idx, n_local_idx), LDG_VEC_SIZE) if HAS_BIAS: - bias_vec = BIAS_.vec_load( + bias_vec = bias_g.vec_load( (tile_n_offset + n_local_idx,), LDG_VEC_SIZE ) vec = vec + bias_vec @@ -1067,7 +1004,7 @@ def write_c_frags_to_lds(tile_c_frags_): tile_actives[tile_i], results_=[], has_else=False ) with ir.InsertionPoint(tile_init_if.then_block): - zero_c_tile(tile_n_offsets[tile_i]) + zero_c_tile(BIAS_, tile_n_offsets[tile_i]) scf.YieldOp([]) scf.YieldOp([]) rocdl.sched_barrier(0) @@ -1080,7 +1017,7 @@ def write_c_frags_to_lds(tile_c_frags_): tile_actives[tile_i], results_=[], has_else=False ) with ir.InsertionPoint(tile_init_if.then_block): - init_split_k_counter(tile_counter_indices[tile_i]) + init_split_k_counter(COUNTER, tile_counter_indices[tile_i]) scf.YieldOp([]) scf.YieldOp([]) rocdl.sched_barrier(0) @@ -1088,26 +1025,92 @@ def write_c_frags_to_lds(tile_c_frags_): cond_ks0_if = scf.IfOp(cond_ks0, 
results_=[], has_else=False) with ir.InsertionPoint(cond_ks0_if.then_block): - cleanup_stale_counters_once() + cleanup_stale_counters_once(COUNTER_) scf.YieldOp([]) rocdl.sched_barrier(0) gpu.barrier() if B_TO_LDS: + def ldg_sts_b_async(k_offset, lds_stage, tile_n_offset): + for i in range_constexpr(LDG_REG_B_COUNT_AS): + global_tid = BLOCK_THREADS * i + tid + n_local_idx = global_tid // LDG_B_X_THREADS_AS + k_local_idx = global_tid % LDG_B_X_THREADS_AS * LDG_ASYNC_VEC_SIZE + col_in_bytes = k_local_idx * DTYPE_BYTES + col_in_bytes = swizzle_xor16(n_local_idx, col_in_bytes, k_blocks16) + col_idx = fx.Index(k_offset + col_in_bytes // DTYPE_BYTES) + slot_valid = arith.cmpi( + arith.CmpIPredicate.ult, + fx.Index(global_tid), + fx.Index(LDG_B_TOTAL_VECS_AS), + ) + slot_if = scf.IfOp(slot_valid, results_=[], has_else=False) + with ir.InsertionPoint(slot_if.then_block): + global_offset = B_.linear_offset( + (tile_n_offset + fx.Index(n_local_idx), col_idx) + ) + global_offset = arith.index_cast( + T.i32, global_offset * DTYPE_BYTES + ) + lds_offset = ( + bs_.linear_offset( + (fx.Index(lds_stage), n_local_idx, k_local_idx) + ) + * DTYPE_BYTES + ) + lds_ptr_type = ir.Type.parse("!llvm.ptr<3>") + lds_addr = ( + memref.extract_aligned_pointer_as_index(bs_.memptr) + + lds_offset + ) + lds_addr_ = rocdl.readfirstlane( + T.i64, arith.index_cast(T.i64, lds_addr) + ) + lds_ptr = llvm.inttoptr(lds_ptr_type, lds_addr_) + rocdl.raw_ptr_buffer_load_lds( + B_.rsrc, + lds_ptr, + arith.constant(DMA_BYTES, type=T.i32), + global_offset, + arith.constant(0, type=T.i32), + arith.constant(0, type=T.i32), + arith.constant(1, type=T.i32), + ) + scf.YieldOp([]) + + def lds_matrix_b(lds_stage): + s = fx.Index(lds_stage) + b_frags = [0] * B_FRAGS_LEN + for ii in range_constexpr(WARP_N_STEPS): + warp_atom_n_idx = warp_n_idx + ii * WARP_ATOM_N + for kk in range_constexpr(WARP_K_STEPS): + warp_atom_k_idx = kk * WARP_ATOM_K + row = warp_atom_n_idx + ldmatrix_b_n_idx + col_in_bytes = ( + warp_atom_k_idx + ldmatrix_b_k_vec_idx + ) * DTYPE_BYTES + col_in_bytes = swizzle_xor16(row, col_in_bytes, k_blocks16) + vec = bs_.vec_load( + (s, row, col_in_bytes // DTYPE_BYTES), + WMMA_B_FRAG_VALUES * MFMA_PER_WARP_K, + ) + b_frags[kk * WARP_N_STEPS + ii] = vec + return b_frags + def run_b_to_lds_tile(tile_n_offset, tile_counter_idx): c_frags_local = [acc_init] * C_FRAGS_LEN if IS_SPLIT_K: cond_ks0_if = scf.IfOp(cond_ks0, results_=[], has_else=False) with ir.InsertionPoint(cond_ks0_if.then_block): - zero_c_tile(tile_n_offset) + zero_c_tile(BIAS_, tile_n_offset) scf.YieldOp([]) rocdl.sched_barrier(0) gpu.barrier() cond_ks0_if = scf.IfOp(cond_ks0, results_=[], has_else=False) with ir.InsertionPoint(cond_ks0_if.then_block): - init_split_k_counter(tile_counter_idx) + init_split_k_counter(COUNTER, tile_counter_idx) scf.YieldOp([]) rocdl.sched_barrier(0) gpu.barrier() @@ -1202,10 +1205,10 @@ def hot_loop_scheduler(): write_c_frags_to_lds(c_frags_local) gpu.barrier() if IS_SPLIT_K: - split_k_barrier(tile_counter_idx) + split_k_barrier(COUNTER, tile_counter_idx) store_split_k_tile(tile_n_offset) else: - store_c_tile(tile_n_offset) + store_c_tile(BIAS_, tile_n_offset) gpu.barrier() for tile_i in range_constexpr(tile_group): @@ -1339,10 +1342,10 @@ def hot_loop_scheduler(): write_c_frags_to_lds(tile_c_frags[tile_i]) gpu.barrier() if IS_SPLIT_K: - split_k_barrier(tile_counter_indices[tile_i]) + split_k_barrier(COUNTER, tile_counter_indices[tile_i]) store_split_k_tile(tile_n_offsets[tile_i]) else: - store_c_tile(tile_n_offsets[tile_i]) + 
store_c_tile(BIAS_, tile_n_offsets[tile_i]) gpu.barrier() scf.YieldOp([]) diff --git a/aiter/ops/flydsl/kernels/splitk_hgemm.py b/aiter/ops/flydsl/kernels/splitk_hgemm.py index 6c4593441b..471edc1aa5 100644 --- a/aiter/ops/flydsl/kernels/splitk_hgemm.py +++ b/aiter/ops/flydsl/kernels/splitk_hgemm.py @@ -234,8 +234,7 @@ def hgemm_kernel( A_ = GTensor(A, dtype=dtype_, shape=(-1, k)) B_ = GTensor(B, dtype=dtype_, shape=(n, k)) C_ = GTensor(C, dtype=dtype_, shape=(-1, n)) - if HAS_BIAS: - BIAS_ = GTensor(BIAS, dtype=dtype_, shape=(n,)) + BIAS_ = GTensor(BIAS, dtype=dtype_, shape=(n,)) base_ptr = allocator.get_base() smem_a_ptr = SmemPtr( base_ptr, smem_a_offset, dtype_, shape=(STAGES * BLOCK_M * BLOCK_K,) @@ -262,8 +261,9 @@ def hgemm_kernel( LDG_VEC_SIZE, ), ) - if const_expr(IS_SPLIT_K): - COUNTER_ = GTensor(COUNTER, dtype=T.i32, shape=(-1,)) + else: + SHUFFLED_B_ = GTensor(B, dtype=dtype_, shape=(n, k)) + COUNTER_ = GTensor(COUNTER, dtype=T.i32, shape=(-1,)) tid = fx.Int32(fx.thread_idx.x) wid = tid // WARP_SIZE @@ -292,7 +292,7 @@ def hgemm_kernel( C_FRAGS_LEN = WARP_M_STEPS * WARP_N_STEPS c_frags = [acc_init] * C_FRAGS_LEN - def zero_c(): + def zero_c(bias_g, counter_tensor, counter_g): cond_ks0 = arith.cmpi(arith.CmpIPredicate.eq, ks_idx, fx.Index(0)) cond_ks0_if = scf.IfOp(cond_ks0, results_=[], has_else=False) with ir.InsertionPoint(cond_ks0_if.then_block): @@ -304,7 +304,7 @@ def zero_c(): row_idx = m_offset + fx.Index(m_local_idx) init_vec = zero_vec if HAS_BIAS: - init_vec = BIAS_.vec_load( + init_vec = bias_g.vec_load( (n_offset + n_local_idx,), LDG_VEC_SIZE ) cond_boundary = arith.cmpi( @@ -330,7 +330,7 @@ def zero_c(): is_t0_cond_if = scf.IfOp(is_t0_cond, results_=[], has_else=False) with ir.InsertionPoint(is_t0_cond_if.then_block): counter_base_ptr = fly.extract_aligned_pointer_as_index( - _ptr_type, fly_values(COUNTER)[0] + _ptr_type, fly_values(counter_tensor)[0] ) counter_base_ptr = llvm.PtrToIntOp( _i64_type, counter_base_ptr @@ -381,7 +381,7 @@ def zero_c(): ) * SPLIT_K_COUNTER_MAX_LEN ) + fx.Index(tid) - COUNTER_[fx.Index(clean_counter_idx)] = arith.constant( + counter_g[fx.Index(clean_counter_idx)] = arith.constant( 0, type=T.i32 ) scf.YieldOp([]) @@ -389,7 +389,7 @@ def zero_c(): rocdl.sched_barrier(0) gpu.barrier() - def split_k_barrier(): + def split_k_barrier(counter_tensor): init_cur = arith.constant(0, type=T.i32) w = scf.WhileOp([T.i32], [init_cur]) before = ir.Block.create_at_start(w.before, [T.i32]) @@ -402,7 +402,7 @@ def split_k_barrier(): scf.ConditionOp(need_wait, [cur]) with ir.InsertionPoint(after): counter_base_ptr = fly.extract_aligned_pointer_as_index( - _ptr_type, fly_values(COUNTER)[0] + _ptr_type, fly_values(counter_tensor)[0] ) counter_base_ptr = llvm.PtrToIntOp(_i64_type, counter_base_ptr).result counter_byte_offset = arith.index_cast( @@ -476,20 +476,6 @@ def ldg_b(k_offset): vecs.append(B_.vec_load((safe_row_idx, col_idx), LDG_VEC_SIZE)) return vecs - def sts_b(vecs, lds_stage): - for i in range_constexpr(LDG_REG_B_COUNT): - global_tid = BLOCK_THREADS * i + tid - n_local_idx = global_tid // LDG_B_X_THREADS - k_local_idx = global_tid % LDG_B_X_THREADS * LDG_VEC_SIZE - col_in_bytes = swizzle_xor16( - n_local_idx, k_local_idx * DTYPE_BYTES, k_blocks16 - ) - bs_.vec_store( - (fx.Index(lds_stage), n_local_idx, col_in_bytes // DTYPE_BYTES), - vecs[i], - LDG_VEC_SIZE, - ) - def ldg_sts_a_async(k_offset, lds_stage): for i in range_constexpr(LDG_REG_A_COUNT_AS): global_tid = BLOCK_THREADS * i + tid @@ -529,45 +515,6 @@ def ldg_sts_a_async(k_offset, 
lds_stage): arith.constant(1, type=T.i32), ) - def ldg_sts_b_async(k_offset, lds_stage): - for i in range_constexpr(LDG_REG_B_COUNT_AS): - global_tid = BLOCK_THREADS * i + tid - n_local_idx = global_tid // LDG_B_X_THREADS_AS - k_local_idx = global_tid % LDG_B_X_THREADS_AS * LDG_ASYNC_VEC_SIZE - col_in_bytes = swizzle_xor16( - n_local_idx, k_local_idx * DTYPE_BYTES, k_blocks16 - ) - row_idx = n_offset + fx.Index(n_local_idx) - safe_row_idx = arith.select( - arith.cmpi(arith.CmpIPredicate.ult, row_idx, fx.Index(n)), - row_idx, - fx.Index(0), - ) - col_idx = fx.Index(k_offset + col_in_bytes // DTYPE_BYTES) - global_offset = B_.linear_offset((safe_row_idx, col_idx)) * DTYPE_BYTES - global_offset = arith.index_cast(T.i32, global_offset) - lds_offset = ( - bs_.linear_offset((fx.Index(lds_stage), n_local_idx, k_local_idx)) - * DTYPE_BYTES - ) - lds_ptr_type = ir.Type.parse("!llvm.ptr<3>") - lds_addr = ( - memref.extract_aligned_pointer_as_index(bs_.memptr) + lds_offset - ) - lds_addr_ = rocdl.readfirstlane( - T.i64, arith.index_cast(T.i64, lds_addr) - ) - lds_ptr = llvm.inttoptr(lds_ptr_type, lds_addr_) - rocdl.raw_ptr_buffer_load_lds( - B_.rsrc, - lds_ptr, - arith.constant(DMA_BYTES, type=T.i32), - global_offset, - arith.constant(0, type=T.i32), - arith.constant(0, type=T.i32), - arith.constant(1, type=T.i32), - ) - def lds_matrix_a(lds_stage): s = fx.Index(lds_stage) a_frags = [0] * (WARP_K_STEPS * WARP_M_STEPS) @@ -586,24 +533,6 @@ def lds_matrix_a(lds_stage): ) return a_frags - def lds_matrix_b(lds_stage): - s = fx.Index(lds_stage) - b_frags = [0] * (WARP_K_STEPS * WARP_N_STEPS) - for ii in range_constexpr(WARP_N_STEPS): - warp_atom_n_idx = warp_n_idx + ii * WARP_ATOM_N - for kk in range_constexpr(WARP_K_STEPS): - warp_atom_k_idx = kk * WARP_ATOM_K - row = warp_atom_n_idx + ldmatrix_b_n_idx - col_in_bytes = ( - warp_atom_k_idx + ldmatrix_b_k_vec_idx - ) * DTYPE_BYTES - col_in_bytes = swizzle_xor16(row, col_in_bytes, k_blocks16) - b_frags[kk * WARP_N_STEPS + ii] = bs_.vec_load( - (s, row, col_in_bytes // DTYPE_BYTES), - WMMA_B_FRAG_VALUES * MFMA_PER_WARP_K, - ) - return b_frags - def ldg_matrix_b(k_offset): vecs = [] b_n_intra_base = ldmatrix_b_n_idx @@ -683,9 +612,70 @@ def block_mma_sync(a_frags, b_frags, c_frags): return c_frags_new if const_expr(IS_SPLIT_K): - zero_c() + zero_c(BIAS_, COUNTER, COUNTER_) if const_expr(B_TO_LDS): + + def ldg_sts_b_async(k_offset, lds_stage): + for i in range_constexpr(LDG_REG_B_COUNT_AS): + global_tid = BLOCK_THREADS * i + tid + n_local_idx = global_tid // LDG_B_X_THREADS_AS + k_local_idx = global_tid % LDG_B_X_THREADS_AS * LDG_ASYNC_VEC_SIZE + col_in_bytes = swizzle_xor16( + n_local_idx, k_local_idx * DTYPE_BYTES, k_blocks16 + ) + row_idx = n_offset + fx.Index(n_local_idx) + safe_row_idx = arith.select( + arith.cmpi(arith.CmpIPredicate.ult, row_idx, fx.Index(n)), + row_idx, + fx.Index(0), + ) + col_idx = fx.Index(k_offset + col_in_bytes // DTYPE_BYTES) + global_offset = ( + B_.linear_offset((safe_row_idx, col_idx)) * DTYPE_BYTES + ) + global_offset = arith.index_cast(T.i32, global_offset) + lds_offset = ( + bs_.linear_offset( + (fx.Index(lds_stage), n_local_idx, k_local_idx) + ) + * DTYPE_BYTES + ) + lds_ptr_type = ir.Type.parse("!llvm.ptr<3>") + lds_addr = ( + memref.extract_aligned_pointer_as_index(bs_.memptr) + lds_offset + ) + lds_addr_ = rocdl.readfirstlane( + T.i64, arith.index_cast(T.i64, lds_addr) + ) + lds_ptr = llvm.inttoptr(lds_ptr_type, lds_addr_) + rocdl.raw_ptr_buffer_load_lds( + B_.rsrc, + lds_ptr, + arith.constant(DMA_BYTES, type=T.i32), + 
global_offset, + arith.constant(0, type=T.i32), + arith.constant(0, type=T.i32), + arith.constant(1, type=T.i32), + ) + + def lds_matrix_b(lds_stage): + s = fx.Index(lds_stage) + b_frags = [0] * (WARP_K_STEPS * WARP_N_STEPS) + for ii in range_constexpr(WARP_N_STEPS): + warp_atom_n_idx = warp_n_idx + ii * WARP_ATOM_N + for kk in range_constexpr(WARP_K_STEPS): + warp_atom_k_idx = kk * WARP_ATOM_K + row = warp_atom_n_idx + ldmatrix_b_n_idx + col_in_bytes = ( + warp_atom_k_idx + ldmatrix_b_k_vec_idx + ) * DTYPE_BYTES + col_in_bytes = swizzle_xor16(row, col_in_bytes, k_blocks16) + b_frags[kk * WARP_N_STEPS + ii] = bs_.vec_load( + (s, row, col_in_bytes // DTYPE_BYTES), + WMMA_B_FRAG_VALUES * MFMA_PER_WARP_K, + ) + return b_frags ldg_sts_a_async(ks_begin, 0) ldg_sts_b_async(ks_begin, 0) gpu.barrier() def hot_loop_scheduler(): @@ -795,11 +785,12 @@ def hot_loop_scheduler(): b_frags = state[2 + C_FRAGS_LEN + A_FRAGS_LEN :] if const_expr(ASYNC_COPY): ldg_sts_a_async(k_offset + BLOCK_K, next_stage) + b_frags_next = ldg_matrix_b(k_offset + BLOCK_K) + c_frags = block_mma_sync(a_frags, b_frags, c_frags) else: a_regs_next = ldg_a(k_offset + BLOCK_K) - b_frags_next = ldg_matrix_b(k_offset + BLOCK_K) - c_frags = block_mma_sync(a_frags, b_frags, c_frags) - if const_expr(not ASYNC_COPY): + b_frags_next = ldg_matrix_b(k_offset + BLOCK_K) + c_frags = block_mma_sync(a_frags, b_frags, c_frags) sts_a(a_regs_next, next_stage) hot_loop_scheduler() gpu.barrier() @@ -832,7 +823,7 @@ def hot_loop_scheduler(): cs_[lds_m_idx, lds_n_idx] = val.truncf(dtype_) if const_expr(IS_SPLIT_K): - split_k_barrier() + split_k_barrier(COUNTER) out_raw = fly_values(C)[0] out_base_ptr = fly.extract_aligned_pointer_as_index(_ptr_type, out_raw) out_base_int = llvm.PtrToIntOp(_i64_type, out_base_ptr).result From 28d5bd09c607c9c343e48d5d04fc96e80f2be988 Mon Sep 17 00:00:00 2001 From: zhuyuhua-v Date: Mon, 20 Apr 2026 03:46:38 -0500 Subject: [PATCH 09/11] restore 0 to SMALL_M_B_TO_LDS_WAVES_PER_EU_OPTIONS --- aiter/ops/flydsl/kernels/small_m_hgemm.py | 24 ++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/aiter/ops/flydsl/kernels/small_m_hgemm.py b/aiter/ops/flydsl/kernels/small_m_hgemm.py index fda67e801f..e2807c80fe 100644 --- a/aiter/ops/flydsl/kernels/small_m_hgemm.py +++ b/aiter/ops/flydsl/kernels/small_m_hgemm.py @@ -95,9 +95,9 @@ 1024, ) SMALL_M_NON_B_TO_LDS_WAVES_PER_EU_OPTIONS = (0, 2, 4) -# Avoid catalog entries that normalize to the same kernel (hint 0 becomes 2 / 8 -# for wide-N B_TO_LDS in compile_small_m_hgemm_kernel). -SMALL_M_B_TO_LDS_WAVES_PER_EU_OPTIONS = (2, 4) +# Keep 0 for narrow B_TO_LDS shapes where it remains a real candidate, and +# canonicalize only the wide-N B_TO_LDS duplicates at registry emission time.
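This patch keeps hint 0 in the option tuples and instead collapses duplicates when the registry is emitted, via the _canonicalize_small_m_registry_config hunk that follows. A distilled sketch of that canonicalize-then-dedup flow (dedup_canonical is illustrative; the real code keys a seen_configs set on tuple(sorted(config.items())), as shown below):

def dedup_canonical(configs, canonicalize):
    # Canonicalize first, then dedup: two hints that normalize to the same
    # effective kernel settings collapse into a single registry entry.
    seen, unique = set(), []
    for cfg in configs:
        cfg = canonicalize(cfg)
        key = tuple(sorted(cfg.items()))
        if key not in seen:
            seen.add(key)
            unique.append(cfg)
    return unique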
+SMALL_M_B_TO_LDS_WAVES_PER_EU_OPTIONS = (0, 2, 4) SMALL_M_B_TO_LDS_UNROLL_OPTIONS = (8, 16) SMALL_M_N_TILE_REPEAT_OPTIONS = (1, 2, 4) SMALL_M_PERSISTENT_N_TILE_OPTIONS = (2, 4, 8) @@ -299,6 +299,23 @@ def add_variant( return tuple(variants) +def _canonicalize_small_m_registry_config(config: dict) -> dict: + """Match registry metadata to the effective compile-time kernel settings.""" + canonical = dict(config) + wide_n_b_to_lds = ( + canonical["b_to_lds"] + and canonical["n_tile_repeat"] == 1 + and canonical["tile_n"] >= 128 + and canonical["block_n_warps"] >= 2 + ) + if canonical["b_to_lds"]: + if canonical["b_to_lds_unroll"] <= 0: + canonical["b_to_lds_unroll"] = 8 + if canonical["waves_per_eu"] <= 0 and wide_n_b_to_lds: + canonical["waves_per_eu"] = 2 + return canonical + + def iter_small_m_registry_configs( dtype: str, out_dtype: str, @@ -360,6 +377,7 @@ def iter_small_m_registry_configs( ) except ValueError: continue + config = _canonicalize_small_m_registry_config(config) config_key = tuple(sorted(config.items())) if config_key in seen_configs: continue From b8df1e4d4516a6d9010868ca0b9ec3e4be4cd7c9 Mon Sep 17 00:00:00 2001 From: zhuyuhua-v Date: Mon, 20 Apr 2026 05:14:15 -0500 Subject: [PATCH 10/11] update gemm code --- aiter/ops/flydsl/kernels/small_m_hgemm.py | 40 ++++++++++++----------- aiter/ops/flydsl/kernels/splitk_hgemm.py | 12 ++++--- 2 files changed, 28 insertions(+), 24 deletions(-) diff --git a/aiter/ops/flydsl/kernels/small_m_hgemm.py b/aiter/ops/flydsl/kernels/small_m_hgemm.py index e2807c80fe..9425b3f04f 100644 --- a/aiter/ops/flydsl/kernels/small_m_hgemm.py +++ b/aiter/ops/flydsl/kernels/small_m_hgemm.py @@ -46,7 +46,7 @@ from flydsl._mlir.dialects import fly, llvm, memref, scf from flydsl.compiler.kernel_function import CompilationContext from flydsl.compiler.protocol import fly_values -from flydsl.expr import arith, gpu, range_constexpr, rocdl, vector +from flydsl.expr import arith, const_expr, gpu, range_constexpr, rocdl, vector from flydsl.expr.typing import T from flydsl.runtime.device import get_rocm_arch from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr @@ -478,10 +478,12 @@ def compile_small_m_hgemm_kernel( ) WAVES_PER_EU = ( int(WAVES_PER_EU_HINT) - if WAVES_PER_EU_HINT > 0 - else (2 if WIDE_N_B_TO_LDS else 0) + if const_expr(WAVES_PER_EU_HINT > 0) + else (2 if const_expr(WIDE_N_B_TO_LDS) else 0) + ) + EFFECTIVE_B_TO_LDS_UNROLL = ( + int(B_TO_LDS_UNROLL) if const_expr(B_TO_LDS_UNROLL > 0) else 8 ) - EFFECTIVE_B_TO_LDS_UNROLL = int(B_TO_LDS_UNROLL) if B_TO_LDS_UNROLL > 0 else 8 BLOCK_MK_SIZE = BLOCK_M * BLOCK_K BLOCK_NK_SIZE = BLOCK_N * BLOCK_K @@ -532,7 +534,7 @@ def compile_small_m_hgemm_kernel( n_tile_repeat=N_TILE_REPEAT, persistent_n_tiles=PERSISTENT_N_TILES, waves_per_eu=WAVES_PER_EU, - b_to_lds_unroll=EFFECTIVE_B_TO_LDS_UNROLL if B_TO_LDS else 0, + b_to_lds_unroll=EFFECTIVE_B_TO_LDS_UNROLL if const_expr(B_TO_LDS) else 0, b_to_lds=B_TO_LDS, has_bias=HAS_BIAS, ) @@ -568,7 +570,7 @@ def small_m_hgemm_kernel( shape=(STAGES * BLOCK_M * BLOCK_K,), ) as_ = STensor(smem_a_ptr, dtype_, shape=(STAGES, BLOCK_M, BLOCK_K)) - if B_TO_LDS: + if const_expr(B_TO_LDS): smem_b_ptr = SmemPtr( base_ptr, smem_b_offset, @@ -590,7 +592,7 @@ def small_m_hgemm_kernel( ks_idx = fx.Index(fx.block_idx.z) ks_begin = arith.index_cast(T.i32, ks_idx * ks) block_n_tiles = n // BLOCK_N - tile_group = PERSISTENT_N_TILES if PERSISTENT_N else N_TILE_REPEAT + tile_group = PERSISTENT_N_TILES if const_expr(PERSISTENT_N) else N_TILE_REPEAT m_offset = fx.Index(block_m_idx * BLOCK_M) 
tile_block_n_indices = [ @@ -639,7 +641,7 @@ def zero_c_tile(bias_g, tile_n_offset): n_local_idx = global_tid % LDG_C_X_THREADS * LDG_VEC_SIZE row_idx = m_offset + fx.Index(m_local_idx) init_vec = zero_vec - if HAS_BIAS: + if const_expr(HAS_BIAS): init_vec = bias_g.vec_load( (tile_n_offset + n_local_idx,), LDG_VEC_SIZE ) @@ -888,7 +890,7 @@ def ldg_matrix_b(k_offset, tile_n_offset): return vecs def maybe_ldg_matrix_b(k_offset, tile_n_offset, tile_active): - if N_TILE_REPEAT == 1: + if const_expr(N_TILE_REPEAT == 1): return ldg_matrix_b(k_offset, tile_n_offset) load_if = scf.IfOp( tile_active, @@ -982,7 +984,7 @@ def store_c_tile(bias_g, tile_n_offset): cond_boundary_if = scf.IfOp(cond_boundary, results_=[], has_else=False) with ir.InsertionPoint(cond_boundary_if.then_block): vec = cs_.vec_load((m_local_idx, n_local_idx), LDG_VEC_SIZE) - if HAS_BIAS: + if const_expr(HAS_BIAS): bias_vec = bias_g.vec_load( (tile_n_offset + n_local_idx,), LDG_VEC_SIZE ) @@ -1012,9 +1014,9 @@ def write_c_frags_to_lds(tile_c_frags_): ) cs_[lds_m_idx, lds_n_idx] = val.truncf(dtype_) - if IS_SPLIT_K: + if const_expr(IS_SPLIT_K): cond_ks0 = arith.cmpi(arith.CmpIPredicate.eq, ks_idx, fx.Index(0)) - if not B_TO_LDS: + if const_expr(not B_TO_LDS): cond_ks0_if = scf.IfOp(cond_ks0, results_=[], has_else=False) with ir.InsertionPoint(cond_ks0_if.then_block): for tile_i in range_constexpr(N_TILE_REPEAT): @@ -1048,7 +1050,7 @@ def write_c_frags_to_lds(tile_c_frags_): rocdl.sched_barrier(0) gpu.barrier() - if B_TO_LDS: + if const_expr(B_TO_LDS): def ldg_sts_b_async(k_offset, lds_stage, tile_n_offset): for i in range_constexpr(LDG_REG_B_COUNT_AS): @@ -1118,7 +1120,7 @@ def lds_matrix_b(lds_stage): def run_b_to_lds_tile(tile_n_offset, tile_counter_idx): c_frags_local = [acc_init] * C_FRAGS_LEN - if IS_SPLIT_K: + if const_expr(IS_SPLIT_K): cond_ks0_if = scf.IfOp(cond_ks0, results_=[], has_else=False) with ir.InsertionPoint(cond_ks0_if.then_block): zero_c_tile(BIAS_, tile_n_offset) @@ -1142,7 +1144,7 @@ def hot_loop_scheduler(): WARP_K_STEPS * WARP_M_STEPS * WARP_N_STEPS * MFMA_PER_WARP_K ) LDG_TOTAL = LDG_REG_A_COUNT_AS + LDG_REG_B_COUNT_AS - if WIDE_N_B_TO_LDS: + if const_expr(WIDE_N_B_TO_LDS): for _ in range_constexpr(WARP_K_STEPS * WARP_M_STEPS): rocdl.sched_dsrd(1) for _ in range_constexpr(WARP_K_STEPS * WARP_N_STEPS): @@ -1222,7 +1224,7 @@ def hot_loop_scheduler(): write_c_frags_to_lds(c_frags_local) gpu.barrier() - if IS_SPLIT_K: + if const_expr(IS_SPLIT_K): split_k_barrier(COUNTER, tile_counter_idx) store_split_k_tile(tile_n_offset) else: @@ -1359,7 +1361,7 @@ def hot_loop_scheduler(): with ir.InsertionPoint(tile_store_if.then_block): write_c_frags_to_lds(tile_c_frags[tile_i]) gpu.barrier() - if IS_SPLIT_K: + if const_expr(IS_SPLIT_K): split_k_barrier(COUNTER, tile_counter_indices[tile_i]) store_split_k_tile(tile_n_offsets[tile_i]) else: @@ -1382,7 +1384,7 @@ def launch_small_m_hgemm_kernel( ctx = CompilationContext.get_current() with ir.InsertionPoint(ctx.gpu_module_body): allocator.finalize() - if WAVES_PER_EU > 0: + if const_expr(WAVES_PER_EU > 0): for op in ctx.gpu_module_body.operations: if hasattr(op, "attributes") and op.OPERATION_NAME == "gpu.func": op.attributes["rocdl.waves_per_eu"] = ir.IntegerAttr.get( @@ -1390,7 +1392,7 @@ def launch_small_m_hgemm_kernel( ) bm = (m + BLOCK_M - 1) // BLOCK_M - tile_group = PERSISTENT_N_TILES if PERSISTENT_N else N_TILE_REPEAT + tile_group = PERSISTENT_N_TILES if const_expr(PERSISTENT_N) else N_TILE_REPEAT bn = (n // BLOCK_N + tile_group - 1) // tile_group 
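To make the launch-grid arithmetic just above concrete, a worked example with illustrative values (not taken from any tuned config):

# BLOCK_M = 16, BLOCK_N = 128, tile_group = PERSISTENT_N_TILES = 4,
# m = 8, n = 2560:
#   bm = (8 + 16 - 1) // 16   -> 1 workgroup along M
#   n // BLOCK_N = 20 N-tiles, each workgroup owning tile_group = 4 of them
#   bn = (20 + 4 - 1) // 4    -> 5 workgroups along N
# SPLIT_K, when enabled, contributes the third grid dimension (block_idx.z).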
small_m_hgemm_kernel._func.__name__ = KERNEL_NAME small_m_hgemm_kernel(C, A, B, BIAS, m, COUNTER, signal_state).launch( diff --git a/aiter/ops/flydsl/kernels/splitk_hgemm.py b/aiter/ops/flydsl/kernels/splitk_hgemm.py index 471edc1aa5..e24e374c4f 100644 --- a/aiter/ops/flydsl/kernels/splitk_hgemm.py +++ b/aiter/ops/flydsl/kernels/splitk_hgemm.py @@ -303,7 +303,7 @@ def zero_c(bias_g, counter_tensor, counter_g): n_local_idx = global_tid % LDG_C_X_THREADS * LDG_VEC_SIZE row_idx = m_offset + fx.Index(m_local_idx) init_vec = zero_vec - if HAS_BIAS: + if const_expr(HAS_BIAS): init_vec = bias_g.vec_load( (n_offset + n_local_idx,), LDG_VEC_SIZE ) @@ -686,11 +686,11 @@ def hot_loop_scheduler(): for _ in range_constexpr(WARP_K_STEPS * WARP_N_STEPS): rocdl.sched_dsrd(1) for _ in range_constexpr( - LDG_REG_A_COUNT_AS if ASYNC_COPY else LDG_REG_A_COUNT + LDG_REG_A_COUNT_AS if const_expr(ASYNC_COPY) else LDG_REG_A_COUNT ): rocdl.sched_vmem(1) for _ in range_constexpr( - LDG_REG_B_COUNT_AS if ASYNC_COPY else LDG_REG_B_COUNT + LDG_REG_B_COUNT_AS if const_expr(ASYNC_COPY) else LDG_REG_B_COUNT ): rocdl.sched_vmem(1) for _ in range_constexpr( @@ -753,7 +753,9 @@ def hot_loop_scheduler(): mfma_total = ( WARP_K_STEPS * WARP_M_STEPS * WARP_N_STEPS * MFMA_PER_WARP_K ) - ldg_reg_a_count_ = LDG_REG_A_COUNT_AS if ASYNC_COPY else LDG_REG_A_COUNT + ldg_reg_a_count_ = ( + LDG_REG_A_COUNT_AS if const_expr(ASYNC_COPY) else LDG_REG_A_COUNT + ) ldg_total = ldg_reg_a_count_ + WARP_K_STEPS * WARP_N_STEPS mfma_ = OnlineScheduler(mfma_total, mfma_total) ldg_ = OnlineScheduler(ldg_total, ldg_total) @@ -889,7 +891,7 @@ def hot_loop_scheduler(): cond_boundary_if = scf.IfOp(cond_boundary, results_=[], has_else=False) with ir.InsertionPoint(cond_boundary_if.then_block): vec = cs_.vec_load((m_local_idx, n_local_idx), LDG_VEC_SIZE) - if HAS_BIAS: + if const_expr(HAS_BIAS): bias_vec = BIAS_.vec_load( (n_offset + n_local_idx,), LDG_VEC_SIZE ) From ef19b066634c7aca43ca96c9c149ec070682a2bd Mon Sep 17 00:00:00 2001 From: zhuyuhua-v Date: Tue, 21 Apr 2026 02:22:12 -0500 Subject: [PATCH 11/11] refact code --- aiter/ops/flydsl/kernels/small_m_hgemm.py | 57 ++++++++++++----------- aiter/ops/flydsl/kernels/splitk_hgemm.py | 27 ++++++----- 2 files changed, 44 insertions(+), 40 deletions(-) diff --git a/aiter/ops/flydsl/kernels/small_m_hgemm.py b/aiter/ops/flydsl/kernels/small_m_hgemm.py index 9425b3f04f..12f03cef20 100644 --- a/aiter/ops/flydsl/kernels/small_m_hgemm.py +++ b/aiter/ops/flydsl/kernels/small_m_hgemm.py @@ -561,6 +561,7 @@ def small_m_hgemm_kernel( B_ = GTensor(B, dtype=dtype_, shape=(n, k)) C_ = GTensor(C, dtype=dtype_, shape=(-1, n)) BIAS_ = GTensor(BIAS, dtype=dtype_, shape=(n,)) + bs_ = None base_ptr = allocator.get_base() smem_a_ptr = SmemPtr( @@ -633,7 +634,7 @@ def small_m_hgemm_kernel( zero_b_frag = vector.broadcast(B_FRAG_T, c_zero_d) c_frags = [acc_init] * (C_FRAGS_LEN * N_TILE_REPEAT) - def zero_c_tile(bias_g, tile_n_offset): + def zero_c_tile(c_g, bias_g, tile_n_offset): zero_vec = vector.broadcast(T.vec(LDG_VEC_SIZE, dtype_), c_zero_d) for i in range_constexpr(LDG_REG_C_COUNT): global_tid = BLOCK_THREADS * i + tid @@ -650,7 +651,7 @@ def zero_c_tile(bias_g, tile_n_offset): ) cond_boundary_if = scf.IfOp(cond_boundary, results_=[], has_else=False) with ir.InsertionPoint(cond_boundary_if.then_block): - C_.vec_store( + c_g.vec_store( (row_idx, tile_n_offset + n_local_idx), init_vec, LDG_VEC_SIZE ) scf.YieldOp([]) @@ -916,8 +917,8 @@ def block_mma_sync(a_frags, b_frags, c_frags): ) return c_frags_new - def 
store_split_k_tile(tile_n_offset): - out_raw = fly_values(C)[0] + def store_split_k_tile(c_tensor, c_g, c_s, tile_n_offset): + out_raw = fly_values(c_tensor)[0] out_base_ptr = fly.extract_aligned_pointer_as_index(_ptr_type, out_raw) out_base_int = llvm.PtrToIntOp(_i64_type, out_base_ptr).result for i in range_constexpr(LDG_REG_C_COUNT): @@ -931,9 +932,9 @@ def store_split_k_tile(tile_n_offset): ) cond_boundary_if = scf.IfOp(cond_boundary, results_=[], has_else=False) with ir.InsertionPoint(cond_boundary_if.then_block): - pk_val = cs_.vec_load((m_local_idx, n_local_idx), LDG_VEC_SIZE) + pk_val = c_s.vec_load((m_local_idx, n_local_idx), LDG_VEC_SIZE) linear_bytes_offset = ( - C_.linear_offset((m_global_idx, n_global_idx)) * DTYPE_BYTES + c_g.linear_offset((m_global_idx, n_global_idx)) * DTYPE_BYTES ) vec2_ty = T.vec(2, dtype_) for vec_idx in range_constexpr(LDG_VEC_SIZE // 2): @@ -972,7 +973,7 @@ def store_split_k_tile(tile_n_offset): ) scf.YieldOp([]) - def store_c_tile(bias_g, tile_n_offset): + def store_c_tile(bias_g, c_g, c_s, tile_n_offset): for i in range_constexpr(LDG_REG_C_COUNT): global_tid = BLOCK_THREADS * i + tid m_local_idx = fx.Index(global_tid // LDG_C_X_THREADS) @@ -983,13 +984,13 @@ def store_c_tile(bias_g, tile_n_offset): ) cond_boundary_if = scf.IfOp(cond_boundary, results_=[], has_else=False) with ir.InsertionPoint(cond_boundary_if.then_block): - vec = cs_.vec_load((m_local_idx, n_local_idx), LDG_VEC_SIZE) + vec = c_s.vec_load((m_local_idx, n_local_idx), LDG_VEC_SIZE) if const_expr(HAS_BIAS): bias_vec = bias_g.vec_load( (tile_n_offset + n_local_idx,), LDG_VEC_SIZE ) vec = vec + bias_vec - C_.vec_store( + c_g.vec_store( (m_global_idx, tile_n_offset + n_local_idx), vec, LDG_VEC_SIZE ) scf.YieldOp([]) @@ -997,7 +998,7 @@ def store_c_tile(bias_g, tile_n_offset): stmatrix_c_m_vec_idx = w_tid // WMMA_N * WMMA_C_FRAG_VALUES stmatrix_c_n_idx = w_tid % WMMA_N - def write_c_frags_to_lds(tile_c_frags_): + def write_c_frags_to_lds(c_s, tile_c_frags_): for ii in range_constexpr(WARP_M_STEPS): warp_atom_m_idx = warp_m_idx + ii * WARP_ATOM_M for jj in range_constexpr(WARP_N_STEPS): @@ -1012,7 +1013,7 @@ def write_c_frags_to_lds(tile_c_frags_): static_position=[kk], dynamic_position=[], ) - cs_[lds_m_idx, lds_n_idx] = val.truncf(dtype_) + c_s[lds_m_idx, lds_n_idx] = val.truncf(dtype_) if const_expr(IS_SPLIT_K): cond_ks0 = arith.cmpi(arith.CmpIPredicate.eq, ks_idx, fx.Index(0)) @@ -1024,7 +1025,7 @@ def write_c_frags_to_lds(tile_c_frags_): tile_actives[tile_i], results_=[], has_else=False ) with ir.InsertionPoint(tile_init_if.then_block): - zero_c_tile(BIAS_, tile_n_offsets[tile_i]) + zero_c_tile(C_, BIAS_, tile_n_offsets[tile_i]) scf.YieldOp([]) scf.YieldOp([]) rocdl.sched_barrier(0) @@ -1052,7 +1053,7 @@ def write_c_frags_to_lds(tile_c_frags_): if const_expr(B_TO_LDS): - def ldg_sts_b_async(k_offset, lds_stage, tile_n_offset): + def ldg_sts_b_async(bs_s, k_offset, lds_stage, tile_n_offset): for i in range_constexpr(LDG_REG_B_COUNT_AS): global_tid = BLOCK_THREADS * i + tid n_local_idx = global_tid // LDG_B_X_THREADS_AS @@ -1074,14 +1075,14 @@ def ldg_sts_b_async(k_offset, lds_stage, tile_n_offset): T.i32, global_offset * DTYPE_BYTES ) lds_offset = ( - bs_.linear_offset( + bs_s.linear_offset( (fx.Index(lds_stage), n_local_idx, k_local_idx) ) * DTYPE_BYTES ) lds_ptr_type = ir.Type.parse("!llvm.ptr<3>") lds_addr = ( - memref.extract_aligned_pointer_as_index(bs_.memptr) + memref.extract_aligned_pointer_as_index(bs_s.memptr) + lds_offset ) lds_addr_ = rocdl.readfirstlane( @@ -1099,7 +1100,7 
@@ def ldg_sts_b_async(k_offset, lds_stage, tile_n_offset): ) scf.YieldOp([]) - def lds_matrix_b(lds_stage): + def lds_matrix_b(bs_s, lds_stage): s = fx.Index(lds_stage) b_frags = [0] * B_FRAGS_LEN for ii in range_constexpr(WARP_N_STEPS): @@ -1111,7 +1112,7 @@ def lds_matrix_b(lds_stage): warp_atom_k_idx + ldmatrix_b_k_vec_idx ) * DTYPE_BYTES col_in_bytes = swizzle_xor16(row, col_in_bytes, k_blocks16) - vec = bs_.vec_load( + vec = bs_s.vec_load( (s, row, col_in_bytes // DTYPE_BYTES), WMMA_B_FRAG_VALUES * MFMA_PER_WARP_K, ) @@ -1123,7 +1124,7 @@ def run_b_to_lds_tile(tile_n_offset, tile_counter_idx): if const_expr(IS_SPLIT_K): cond_ks0_if = scf.IfOp(cond_ks0, results_=[], has_else=False) with ir.InsertionPoint(cond_ks0_if.then_block): - zero_c_tile(BIAS_, tile_n_offset) + zero_c_tile(C_, BIAS_, tile_n_offset) scf.YieldOp([]) rocdl.sched_barrier(0) gpu.barrier() @@ -1136,7 +1137,7 @@ def run_b_to_lds_tile(tile_n_offset, tile_counter_idx): gpu.barrier() ldg_sts_a_async(ks_begin, 0) - ldg_sts_b_async(ks_begin, 0, tile_n_offset) + ldg_sts_b_async(bs_, ks_begin, 0, tile_n_offset) gpu.barrier() def hot_loop_scheduler(): @@ -1192,10 +1193,10 @@ def hot_loop_scheduler(): with ir.InsertionPoint(cond_if.then_block): next_stage = 1 - current_stage a_frags = lds_matrix_a(current_stage) - b_frags = lds_matrix_b(current_stage) + b_frags = lds_matrix_b(bs_, current_stage) ldg_sts_a_async(k_offset + BLOCK_K, next_stage) ldg_sts_b_async( - k_offset + BLOCK_K, next_stage, tile_n_offset + bs_, k_offset + BLOCK_K, next_stage, tile_n_offset ) c_frags_new = block_mma_sync( a_frags, b_frags, c_frags_local @@ -1219,16 +1220,16 @@ def hot_loop_scheduler(): current_stage = results[1] c_frags_local = results[2 : 2 + C_FRAGS_LEN] a_frags = lds_matrix_a(current_stage) - b_frags = lds_matrix_b(current_stage) + b_frags = lds_matrix_b(bs_, current_stage) c_frags_local = block_mma_sync(a_frags, b_frags, c_frags_local) - write_c_frags_to_lds(c_frags_local) + write_c_frags_to_lds(cs_, c_frags_local) gpu.barrier() if const_expr(IS_SPLIT_K): split_k_barrier(COUNTER, tile_counter_idx) - store_split_k_tile(tile_n_offset) + store_split_k_tile(C, C_, cs_, tile_n_offset) else: - store_c_tile(BIAS_, tile_n_offset) + store_c_tile(BIAS_, C_, cs_, tile_n_offset) gpu.barrier() for tile_i in range_constexpr(tile_group): @@ -1359,13 +1360,13 @@ def hot_loop_scheduler(): tile_actives[tile_i], results_=[], has_else=False ) with ir.InsertionPoint(tile_store_if.then_block): - write_c_frags_to_lds(tile_c_frags[tile_i]) + write_c_frags_to_lds(cs_, tile_c_frags[tile_i]) gpu.barrier() if const_expr(IS_SPLIT_K): split_k_barrier(COUNTER, tile_counter_indices[tile_i]) - store_split_k_tile(tile_n_offsets[tile_i]) + store_split_k_tile(C, C_, cs_, tile_n_offsets[tile_i]) else: - store_c_tile(BIAS_, tile_n_offsets[tile_i]) + store_c_tile(BIAS_, C_, cs_, tile_n_offsets[tile_i]) gpu.barrier() scf.YieldOp([]) diff --git a/aiter/ops/flydsl/kernels/splitk_hgemm.py b/aiter/ops/flydsl/kernels/splitk_hgemm.py index e24e374c4f..787b71babb 100644 --- a/aiter/ops/flydsl/kernels/splitk_hgemm.py +++ b/aiter/ops/flydsl/kernels/splitk_hgemm.py @@ -235,6 +235,7 @@ def hgemm_kernel( B_ = GTensor(B, dtype=dtype_, shape=(n, k)) C_ = GTensor(C, dtype=dtype_, shape=(-1, n)) BIAS_ = GTensor(BIAS, dtype=dtype_, shape=(n,)) + bs_ = None base_ptr = allocator.get_base() smem_a_ptr = SmemPtr( base_ptr, smem_a_offset, dtype_, shape=(STAGES * BLOCK_M * BLOCK_K,) @@ -292,7 +293,7 @@ def hgemm_kernel( C_FRAGS_LEN = WARP_M_STEPS * WARP_N_STEPS c_frags = [acc_init] * C_FRAGS_LEN - def 
zero_c(bias_g, counter_tensor, counter_g): + def zero_c(bias_g, c_g, counter_tensor, counter_g): cond_ks0 = arith.cmpi(arith.CmpIPredicate.eq, ks_idx, fx.Index(0)) cond_ks0_if = scf.IfOp(cond_ks0, results_=[], has_else=False) with ir.InsertionPoint(cond_ks0_if.then_block): @@ -314,7 +315,7 @@ def zero_c(bias_g, counter_tensor, counter_g): cond_boundary, results_=[], has_else=False ) with ir.InsertionPoint(cond_boundary_if.then_block): - C_.vec_store( + c_g.vec_store( (row_idx, n_offset + n_local_idx), init_vec, LDG_VEC_SIZE ) scf.YieldOp([]) @@ -612,11 +613,11 @@ def block_mma_sync(a_frags, b_frags, c_frags): return c_frags_new if const_expr(IS_SPLIT_K): - zero_c(BIAS_, COUNTER, COUNTER_) + zero_c(BIAS_, C_, COUNTER, COUNTER_) if const_expr(B_TO_LDS): - def ldg_sts_b_async(k_offset, lds_stage): + def ldg_sts_b_async(bs_s, k_offset, lds_stage): for i in range_constexpr(LDG_REG_B_COUNT_AS): global_tid = BLOCK_THREADS * i + tid n_local_idx = global_tid // LDG_B_X_THREADS_AS @@ -636,14 +637,15 @@ def ldg_sts_b_async(k_offset, lds_stage): ) global_offset = arith.index_cast(T.i32, global_offset) lds_offset = ( - bs_.linear_offset( + bs_s.linear_offset( (fx.Index(lds_stage), n_local_idx, k_local_idx) ) * DTYPE_BYTES ) lds_ptr_type = ir.Type.parse("!llvm.ptr<3>") lds_addr = ( - memref.extract_aligned_pointer_as_index(bs_.memptr) + lds_offset + memref.extract_aligned_pointer_as_index(bs_s.memptr) + + lds_offset ) lds_addr_ = rocdl.readfirstlane( T.i64, arith.index_cast(T.i64, lds_addr) @@ -659,7 +661,7 @@ def ldg_sts_b_async(k_offset, lds_stage): arith.constant(1, type=T.i32), ) - def lds_matrix_b(lds_stage): + def lds_matrix_b(bs_s, lds_stage): s = fx.Index(lds_stage) b_frags = [0] * (WARP_K_STEPS * WARP_N_STEPS) for ii in range_constexpr(WARP_N_STEPS): @@ -671,13 +673,14 @@ def lds_matrix_b(lds_stage): warp_atom_k_idx + ldmatrix_b_k_vec_idx ) * DTYPE_BYTES col_in_bytes = swizzle_xor16(row, col_in_bytes, k_blocks16) - b_frags[kk * WARP_N_STEPS + ii] = bs_.vec_load( + b_frags[kk * WARP_N_STEPS + ii] = bs_s.vec_load( (s, row, col_in_bytes // DTYPE_BYTES), WMMA_B_FRAG_VALUES * MFMA_PER_WARP_K, ) return b_frags + ldg_sts_a_async(ks_begin, 0) - ldg_sts_b_async(ks_begin, 0) + ldg_sts_b_async(bs_, ks_begin, 0) gpu.barrier() def hot_loop_scheduler(): @@ -720,9 +723,9 @@ def hot_loop_scheduler(): with ir.InsertionPoint(cond_if.then_block): next_stage = 1 - current_stage a_frags = lds_matrix_a(current_stage) - b_frags = lds_matrix_b(current_stage) + b_frags = lds_matrix_b(bs_, current_stage) ldg_sts_a_async(k_offset + BLOCK_K, next_stage) - ldg_sts_b_async(k_offset + BLOCK_K, next_stage) + ldg_sts_b_async(bs_, k_offset + BLOCK_K, next_stage) c_frags_new = block_mma_sync(a_frags, b_frags, c_frags) hot_loop_scheduler() gpu.barrier() @@ -740,7 +743,7 @@ def hot_loop_scheduler(): current_stage = results[1] c_frags = results[2 : 2 + C_FRAGS_LEN] a_frags = lds_matrix_a(current_stage) - b_frags = lds_matrix_b(current_stage) + b_frags = lds_matrix_b(bs_, current_stage) c_frags = block_mma_sync(a_frags, b_frags, c_frags) else: sts_a(ldg_a(ks_begin), 0)
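One reading of the split-K machinery this series keeps threading through (COUNTER, zero_c / zero_c_tile, init_split_k_counter, split_k_barrier, store_split_k_tile): the ks == 0 block seeds the output tile with zero or the bias and publishes that fact through a global i32 counter (rotating counter banks selected by signal_state, with stale banks cleaned up), every block spins on the counter before merging, and each block then merges its partial tile into C. A CPU-thread analogue of that handshake, with threading.Event standing in for the counter spin and a lock standing in for the atomic merge; names and values are illustrative only, not the GPU code:

import threading

SPLIT_K = 4
tile_ready = threading.Event()  # stands in for the per-tile i32 counter
merge_lock = threading.Lock()   # stands in for the atomic add into C
out = {"c": 0.0}

def split_k_worker(ks_idx: int, partial: float, bias: float) -> None:
    if ks_idx == 0:
        out["c"] = bias      # zero_c: seed the tile with bias (or zero)
        tile_ready.set()     # init_split_k_counter: publish readiness
    tile_ready.wait()        # split_k_barrier: wait until the tile is seeded
    with merge_lock:         # store_split_k_tile: merge this split's partial
        out["c"] += partial

threads = [
    threading.Thread(target=split_k_worker, args=(i, float(i + 1), 0.5))
    for i in range(SPLIT_K)
]
for t in threads:
    t.start()
for t in threads:
    t.join()
assert out["c"] == 0.5 + 1.0 + 2.0 + 3.0 + 4.0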