101 changes: 42 additions & 59 deletions aiter/aot/flydsl/gemm.py
@@ -39,8 +39,9 @@
from aiter.aot.flydsl.common import collect_aot_jobs, compile_only_env, job_identity
from aiter.jit.core import AITER_CONFIGS
from aiter.jit.utils.chip_info import get_gfx
from aiter.ops.flydsl.gemm_kernels import get_flydsl_splitk_hgemm_kernel_params
from aiter.ops.flydsl.kernels.hgemm_dispatch import compile_flydsl_hgemm_kernel
from aiter.ops.flydsl.kernels.preshuffle_gemm import compile_preshuffle_gemm_a8
from aiter.ops.flydsl.kernels.splitk_hgemm import compile_hgemm_kernel

# Keep the default AOT coverage aligned with runtime config resolution.
DEFAULT_CSVS = [
@@ -61,19 +62,6 @@
r"(?P<lds_stage>\d+)x(?P<cshuffle>\d+)x(?P<async_copy>\d+)x(?P<waves_per_eu>\d+)_"
r"(?P<scheduler>[A-Za-z0-9_]+)$"
)
_HGEMM_RE = re.compile(
r"^flydsl_gemm(?P<stage>\d+)_"
r"a(?P<a_dtype>[a-z0-9]+)_w(?P<w_dtype>[a-z0-9]+)_(?P<out_dtype>[a-z0-9]+)_"
r"t(?P<tile_m>\d+)x(?P<tile_n>\d+)x(?P<tile_k>\d+)_"
r"split_k(?P<split_k>\d+)_"
r"block_m_warp(?P<block_m_warps>\d+)_"
r"block_n_warp(?P<block_n_warps>\d+)_"
r"async_copy(?P<async_copy>True|False)_"
r"b_to_lds(?P<b_to_lds>True|False)_"
r"b_preshuffle(?P<b_preshuffle>True|False)_"
r"c_to_lds(?P<c_to_lds>True|False)_"
r"(?P<target_gfx>gfx[0-9a-z]+)$"
)
_SHORT_DTYPE = {
"F8": "fp8",
"I8": "int8",
@@ -82,10 +70,15 @@
}


def _parse_bool(value: str) -> bool:
if value == "True":
def _parse_bool(value: Optional[str]) -> bool:
if value is None:
return False
normalized = value.strip().lower()
if normalized == "":
return False
if normalized in {"1", "true", "yes"}:
return True
if value == "False":
if normalized in {"0", "false", "no"}:
return False
raise ValueError(f"Expected True/False, got {value!r}")

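Note on the widened parser: `_parse_bool` now tolerates a missing or empty CSV cell alongside the legacy True/False tokens. A minimal sketch of the accepted inputs (example values are illustrative, not taken from a real tuned CSV):

    assert _parse_bool(None) is False      # absent CSV column
    assert _parse_bool("  ") is False      # blank cell strips to ""
    assert _parse_bool("True") is True     # legacy kernel-name token
    assert _parse_bool("yes") is True      # case-insensitive after strip()
    assert _parse_bool("0") is False
    # _parse_bool("maybe") raises ValueError
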
@@ -120,37 +113,6 @@ def _parse_preshuffle_kernel_name(name: str) -> Optional[Dict]:
}


def _parse_hgemm_kernel_name(name: str) -> Optional[Dict]:
m = _HGEMM_RE.fullmatch(name)
if m is None:
return None

a_dtype = m.group("a_dtype")
w_dtype = m.group("w_dtype")
if a_dtype != w_dtype:
raise ValueError(
f"Unsupported mixed HGEMM input dtypes in {name!r}: {a_dtype} vs {w_dtype}"
)

return {
"kind": "hgemm",
"stage": int(m.group("stage")),
"dtype": a_dtype,
"out_dtype": m.group("out_dtype"),
"tile_m": int(m.group("tile_m")),
"tile_n": int(m.group("tile_n")),
"tile_k": int(m.group("tile_k")),
"split_k": int(m.group("split_k")),
"block_m_warps": int(m.group("block_m_warps")),
"block_n_warps": int(m.group("block_n_warps")),
"async_copy": _parse_bool(m.group("async_copy")),
"b_to_lds": _parse_bool(m.group("b_to_lds")),
"b_preshuffle": _parse_bool(m.group("b_preshuffle")),
"c_to_lds": _parse_bool(m.group("c_to_lds")),
"target_gfx": m.group("target_gfx"),
}


def parse_csv(csv_path: str):
"""Parse a GEMM tuned CSV and return a list of unique FlyDSL compile jobs."""
jobs = []
@@ -171,7 +133,10 @@ def parse_csv(csv_path: str):
if kernel_name.startswith("flydsl_bpreshuflle_"):
params = _parse_preshuffle_kernel_name(kernel_name)
elif kernel_name.startswith("flydsl_gemm"):
params = _parse_hgemm_kernel_name(kernel_name)
params = get_flydsl_splitk_hgemm_kernel_params(kernel_name)
if params is not None:
params = dict(params)
params["kind"] = "hgemm"
else:
params = None

@@ -186,6 +151,7 @@
"m": m,
"n": n,
"k": k,
"has_bias": _parse_bool(row.get("bias")),
**params,
}
key = job_identity(job)
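For an hgemm row, the resulting job merges the CSV shape, the new bias flag, and the fields recovered from the kernel name by `get_flydsl_splitk_hgemm_kernel_params`. A hypothetical example of one such job dict (the values and the exact parameter keys returned by the shared parser are assumptions for illustration):

    job = {
        "m": 4096, "n": 8192, "k": 1024,   # problem shape from the CSV row
        "has_bias": False,                 # parsed from the CSV "bias" column
        "kind": "hgemm",                   # tag added above for dispatch
        "dtype": "f16", "tile_m": 256, "tile_n": 128, "tile_k": 64,
        "split_k": 1, "target_gfx": "gfx942",
    }
    key = job_identity(job)                # de-duplicates repeated configs
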
@@ -235,11 +201,17 @@ def _compile_hgemm_to_cache(
split_k: int,
block_m_warps: int,
block_n_warps: int,
n_tile_repeat: int = 1,
persistent_n_tiles: int = 1,
waves_per_eu: int = 0,
b_to_lds_unroll: int = 0,
async_copy: bool,
b_to_lds: bool,
b_preshuffle: bool,
c_to_lds: bool,
target_gfx: str,
kernel_family: str = "hgemm",
has_bias: bool = False,
**kwargs,
):
del kwargs, out_dtype
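Because the signature ends in `**kwargs` (deleted immediately, together with `out_dtype`), a job dict from `parse_csv` can be splatted straight into this function, and the new knobs fall back to their defaults when the kernel name predates them. A hedged sketch, assuming the parsed job carries every required field (`csv_path` is an illustrative variable):

    for job in parse_csv(csv_path):
        if job["kind"] == "hgemm":
            # Keys the signature does not name (e.g. "kind") are
            # swallowed by **kwargs and discarded.
            _compile_hgemm_to_cache(**job)
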
@@ -259,27 +231,38 @@
out = torch.empty((m, n), device=dev, dtype=torch_dtype)
a = torch.empty((m, k), device=dev, dtype=torch_dtype)
b = torch.empty((n, k), device=dev, dtype=torch_dtype)
bias = torch.empty((n,), device=dev, dtype=torch_dtype)
counter = torch.zeros(
(128 * 3,),
device=dev,
dtype=torch.int32,
)
stream = torch.cuda.current_stream(device=dev)

exe = compile_hgemm_kernel(
exe = compile_flydsl_hgemm_kernel(
dtype,
n,
k,
TILE_M=tile_m,
TILE_N=tile_n,
TILE_K=tile_k,
SPLIT_K=split_k,
BLOCK_M_WARPS=block_m_warps,
BLOCK_N_WARPS=block_n_warps,
B_PRE_SHUFFLE=b_preshuffle,
B_TO_LDS=b_to_lds,
kernel_family=kernel_family,
tile_m=tile_m,
tile_n=tile_n,
tile_k=tile_k,
split_k=split_k,
block_m_warps=block_m_warps,
block_n_warps=block_n_warps,
n_tile_repeat=n_tile_repeat,
persistent_n_tiles=persistent_n_tiles,
waves_per_eu=waves_per_eu,
b_to_lds_unroll=b_to_lds_unroll,
async_copy=async_copy,
b_to_lds=b_to_lds,
b_preshuffle=b_preshuffle,
c_to_lds=c_to_lds,
has_bias=has_bias,
)
_compile_executable_to_cache(exe, out, a, b, m, counter, 0, stream)
# FlyDSL JIT does not accept None for tensor slots; pass a real buffer when
# bias fusion is disabled (matches runtime launcher dummy tensor behavior).
_compile_executable_to_cache(exe, out, a, b, bias, m, counter, 0, stream)
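Follow-up on the dummy buffer: the runtime launcher presumably makes the same substitution, passing the real bias tensor only when fusion is enabled. A sketch of that selection with illustrative names, not the actual launcher API:

    # real_bias / has_bias are hypothetical names for launcher-side state.
    bias_arg = real_bias if has_bias else torch.empty(
        (n,), device=dev, dtype=torch_dtype
    )
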


def _compile_preshuffle_to_cache(