Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 14 additions & 45 deletions aiter/ops/flydsl/gemm_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def _validate_hgemm_tiling(
)
if pack_n != 1:
raise ValueError(
"Current kernel only supports `pack_n=1`; " f"got pack_n={pack_n}"
f"Current kernel only supports `pack_n=1`; got pack_n={pack_n}"
)

warp_atom_m = 16
Expand Down Expand Up @@ -664,7 +664,6 @@ def flydsl_hgemm(

_flydsl_compile_fn = None
_flydsl_import_done = False
_flydsl_kernel_cache: dict = {}


def _get_compile_fn():
Expand Down Expand Up @@ -702,7 +701,7 @@ def flydsl_preshuffle_gemm_a8(
use_async_copy: int = 0,
waves_per_eu: int = 0,
) -> Tensor:
"""Compile (cached) and run a FlyDSL preshuffle GEMM kernel."""
"""Compile (cached via lru_cache) and run a FlyDSL preshuffle GEMM kernel."""
compile_fn = _get_compile_fn()
if compile_fn is None:
raise RuntimeError("[FlyDSL] compile function not available")
Expand Down Expand Up @@ -739,49 +738,19 @@ def flydsl_preshuffle_gemm_a8(
f"[FlyDSL] unsupported output dtype {Out.dtype}; expected torch.bfloat16 or torch.float16"
)

cache_key = (
m,
n,
k,
in_dtype,
out_dtype,
tile_m,
tile_n,
tile_k,
lds_stage,
use_cshuffle_epilog,
use_async_copy,
wpe,
exe = compile_fn(
N=n,
K=k,
tile_m=tile_m,
tile_n=tile_n,
tile_k=tile_k,
in_dtype=in_dtype,
out_dtype=out_dtype,
lds_stage=lds_stage,
use_cshuffle_epilog=bool(use_cshuffle_epilog),
use_async_copy=bool(use_async_copy),
waves_per_eu=wpe,
)
if cache_key not in _flydsl_kernel_cache:
try:
exe = compile_fn(
M=m,
N=n,
K=k,
tile_m=tile_m,
tile_n=tile_n,
tile_k=tile_k,
in_dtype=in_dtype,
out_dtype=out_dtype,
lds_stage=lds_stage,
use_cshuffle_epilog=bool(use_cshuffle_epilog),
use_async_copy=bool(use_async_copy),
waves_per_eu=wpe,
)
_flydsl_kernel_cache[cache_key] = exe
logger.info(
f"[FlyDSL] compiled preshuffle GEMM ({m},{n},{k} {in_dtype} "
f"tile={tile_m}x{tile_n}x{tile_k} lds={lds_stage} csh={use_cshuffle_epilog} "
f"acp={use_async_copy} wpe={waves_per_eu})"
)
except Exception as e:
logger.warning(f"[FlyDSL] compile failed ({m},{n},{k} {in_dtype}): {e}")
_flydsl_kernel_cache[cache_key] = None

exe = _flydsl_kernel_cache[cache_key]
if exe is None:
raise RuntimeError(f"[FlyDSL] kernel compile returned None for ({m},{n},{k})")

def _as_i8(t):
return t.view(torch.int8) if "float8" in str(t.dtype) else t
Expand Down
22 changes: 12 additions & 10 deletions aiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,21 +267,23 @@ def _vgpr_per_simd(gfx: str) -> int:
return 512


_MFMA_M = 16
_MFMA_N = 16
_THREADS_PER_TG = _WAVES_PER_WG * 64
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to be updated for wave32 (the `* 64` in `_THREADS_PER_TG` above assumes 64 threads per wave).



def _estimate_max_wpe(tile_m: int, tile_n: int, total_vgpr: int = 512) -> int:
"""Estimate max achievable waves_per_eu from C-accumulator VGPR pressure.

Each workgroup has _WAVES_PER_WG waves sharing the output tile.
Per-wave VGPR ≈ (accum share) * 1.5 (pipeline overhead for A/B buffers).
Preshuffle GEMM always uses 16x16 MFMA (4 VGPRs per thread per block).
Per-thread accum VGPRs = round_up(tile_m, 16) * round_up(tile_n, 16) / 256.
Estimated total ≈ accum * 1.5 (pipeline overhead for A/B buffers).
Returns the max waves_per_eu that the register file can support.
"""
mfma_m = 16 if tile_m < 32 else 32
mfma_n = 16 if tile_n < 32 else 32
vgpr_per_mfma = 16 if (mfma_m >= 32 and mfma_n >= 32) else 4
blocks_m = math.ceil(tile_m / mfma_m)
blocks_n = math.ceil(tile_n / mfma_n)
c_vgprs_total = blocks_m * blocks_n * vgpr_per_mfma
c_per_wave = c_vgprs_total / _WAVES_PER_WG
est_per_wave = c_per_wave * 1.5
padded_m = math.ceil(tile_m / _MFMA_M) * _MFMA_M
padded_n = math.ceil(tile_n / _MFMA_N) * _MFMA_N
c_per_thread = padded_m * padded_n // _THREADS_PER_TG
est_per_wave = c_per_thread * 1.5
return int(total_vgpr / max(est_per_wave, 1))


Expand Down
28 changes: 21 additions & 7 deletions aiter/ops/flydsl/kernels/preshuffle_gemm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Preshuffle GEMM kernel using the @flyc.kernel API."""

import functools

import flydsl.compiler as flyc
import flydsl.expr as fx
from flydsl.compiler.kernel_function import CompilationContext
Expand Down Expand Up @@ -118,10 +120,10 @@ def _get_preload(tile_m, tile_n, tile_k):
)


@functools.lru_cache(maxsize=1024)
def compile_preshuffle_gemm_a8(
*,
M: int = 0,
N: int = 0,
N: int,
K: int,
tile_m: int,
tile_n: int,
Expand All @@ -140,8 +142,8 @@ def compile_preshuffle_gemm_a8(
Returns a JitFunction that auto-compiles and executes when called.
Signature: launch_fn(arg_c, arg_a, arg_b, arg_scale_a, arg_scale_b, M, N, stream)

Compile-time constants: K, tile_m/n/k, in_dtype, out_dtype (determine loop structure).
Runtime parameters: M, N (passed as i32 kernel args).
Compile-time constants: N, K, tile_m/n/k, in_dtype, out_dtype (determine loop structure).
Runtime parameters: M (passed as i32 kernel arg).
Comment thread
solinzby1 marked this conversation as resolved.

Args:
out_dtype: Output element type, "fp16" or "bf16" (default: "fp16").
Expand Down Expand Up @@ -177,6 +179,19 @@ def compile_preshuffle_gemm_a8(
a_elem_vec_pack = 2 if is_fp4 else 1
b_elem_vec_pack = 2 if is_fp4 else 1

KERNEL_NAME = (
f"preshuffle_gemm_{in_dtype}_{out_dtype}"
Comment thread
solinzby1 marked this conversation as resolved.
f"_t{tile_m}x{tile_n}x{tile_k}"
f"_lds{lds_stage}"
f"_pl{dsrd_preload}x{dvmem_preload}"
)
if use_cshuffle_epilog:
KERNEL_NAME += "_csh"
if use_async_copy:
KERNEL_NAME += "_async"
if waves_per_eu is not None:
KERNEL_NAME += f"_wpe{waves_per_eu}"

tile_k_bytes = int(tile_k) * int(elem_bytes)

if (tile_k_bytes % 64) != 0:
Expand Down Expand Up @@ -1638,6 +1653,7 @@ def launch_gemm(
gx = (i32_m + (tile_m - 1)) // tile_m
gy = i32_n // tile_n

kernel_gemm._func.__name__ = KERNEL_NAME
launcher = kernel_gemm(
arg_c, arg_a, arg_b, arg_scale_a, arg_scale_b, i32_m, i32_n
)
Expand All @@ -1660,8 +1676,7 @@ def launch_gemm(

def compile_preshuffle_gemm_w4(
*,
M: int = 0,
N: int = 0,
N: int,
K: int,
tile_m: int,
tile_n: int,
Expand All @@ -1684,7 +1699,6 @@ def compile_preshuffle_gemm_w4(
if str(get_hip_arch()) != "gfx950":
raise RuntimeError(f"FP4 GEMM requires gfx950, got {get_hip_arch()}")
inner = compile_preshuffle_gemm_a8(
M=M,
N=N,
K=K,
tile_m=tile_m,
Expand Down
Loading