diff --git a/aiter/ops/flydsl/gemm_kernels.py b/aiter/ops/flydsl/gemm_kernels.py index d6e829d91e..44a5fc3fe8 100644 --- a/aiter/ops/flydsl/gemm_kernels.py +++ b/aiter/ops/flydsl/gemm_kernels.py @@ -255,7 +255,7 @@ def _validate_hgemm_tiling( ) if pack_n != 1: raise ValueError( - "Current kernel only supports `pack_n=1`; " f"got pack_n={pack_n}" + f"Current kernel only supports `pack_n=1`; got pack_n={pack_n}" ) warp_atom_m = 16 @@ -664,7 +664,6 @@ def flydsl_hgemm( _flydsl_compile_fn = None _flydsl_import_done = False -_flydsl_kernel_cache: dict = {} def _get_compile_fn(): @@ -702,7 +701,7 @@ def flydsl_preshuffle_gemm_a8( use_async_copy: int = 0, waves_per_eu: int = 0, ) -> Tensor: - """Compile (cached) and run a FlyDSL preshuffle GEMM kernel.""" + """Compile (cached via lru_cache) and run a FlyDSL preshuffle GEMM kernel.""" compile_fn = _get_compile_fn() if compile_fn is None: raise RuntimeError("[FlyDSL] compile function not available") @@ -739,49 +738,19 @@ def flydsl_preshuffle_gemm_a8( f"[FlyDSL] unsupported output dtype {Out.dtype}; expected torch.bfloat16 or torch.float16" ) - cache_key = ( - m, - n, - k, - in_dtype, - out_dtype, - tile_m, - tile_n, - tile_k, - lds_stage, - use_cshuffle_epilog, - use_async_copy, - wpe, + exe = compile_fn( + N=n, + K=k, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + in_dtype=in_dtype, + out_dtype=out_dtype, + lds_stage=lds_stage, + use_cshuffle_epilog=bool(use_cshuffle_epilog), + use_async_copy=bool(use_async_copy), + waves_per_eu=wpe, ) - if cache_key not in _flydsl_kernel_cache: - try: - exe = compile_fn( - M=m, - N=n, - K=k, - tile_m=tile_m, - tile_n=tile_n, - tile_k=tile_k, - in_dtype=in_dtype, - out_dtype=out_dtype, - lds_stage=lds_stage, - use_cshuffle_epilog=bool(use_cshuffle_epilog), - use_async_copy=bool(use_async_copy), - waves_per_eu=wpe, - ) - _flydsl_kernel_cache[cache_key] = exe - logger.info( - f"[FlyDSL] compiled preshuffle GEMM ({m},{n},{k} {in_dtype} " - f"tile={tile_m}x{tile_n}x{tile_k} lds={lds_stage} csh={use_cshuffle_epilog} " - f"acp={use_async_copy} wpe={waves_per_eu})" - ) - except Exception as e: - logger.warning(f"[FlyDSL] compile failed ({m},{n},{k} {in_dtype}): {e}") - _flydsl_kernel_cache[cache_key] = None - - exe = _flydsl_kernel_cache[cache_key] - if exe is None: - raise RuntimeError(f"[FlyDSL] kernel compile returned None for ({m},{n},{k})") def _as_i8(t): return t.view(torch.int8) if "float8" in str(t.dtype) else t diff --git a/aiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_common.py b/aiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_common.py index d1fac6aed7..3a2a2a5877 100644 --- a/aiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_common.py +++ b/aiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_common.py @@ -267,21 +267,23 @@ def _vgpr_per_simd(gfx: str) -> int: return 512 +_MFMA_M = 16 +_MFMA_N = 16 +_THREADS_PER_TG = _WAVES_PER_WG * 64 + + def _estimate_max_wpe(tile_m: int, tile_n: int, total_vgpr: int = 512) -> int: """Estimate max achievable waves_per_eu from C-accumulator VGPR pressure. - Each workgroup has _WAVES_PER_WG waves sharing the output tile. - Per-wave VGPR ≈ (accum share) * 1.5 (pipeline overhead for A/B buffers). + Preshuffle GEMM always uses 16x16 MFMA (4 VGPRs per thread per block). + Per-thread accum VGPRs = round_up(tile_m, 16) * round_up(tile_n, 16) / 256. + Estimated total ≈ accum * 1.5 (pipeline overhead for A/B buffers). Returns the max waves_per_eu that the register file can support. """ - mfma_m = 16 if tile_m < 32 else 32 - mfma_n = 16 if tile_n < 32 else 32 - vgpr_per_mfma = 16 if (mfma_m >= 32 and mfma_n >= 32) else 4 - blocks_m = math.ceil(tile_m / mfma_m) - blocks_n = math.ceil(tile_n / mfma_n) - c_vgprs_total = blocks_m * blocks_n * vgpr_per_mfma - c_per_wave = c_vgprs_total / _WAVES_PER_WG - est_per_wave = c_per_wave * 1.5 + padded_m = math.ceil(tile_m / _MFMA_M) * _MFMA_M + padded_n = math.ceil(tile_n / _MFMA_N) * _MFMA_N + c_per_thread = padded_m * padded_n // _THREADS_PER_TG + est_per_wave = c_per_thread * 1.5 return int(total_vgpr / max(est_per_wave, 1)) diff --git a/aiter/ops/flydsl/kernels/preshuffle_gemm.py b/aiter/ops/flydsl/kernels/preshuffle_gemm.py index 3707c97ec7..309e832eac 100644 --- a/aiter/ops/flydsl/kernels/preshuffle_gemm.py +++ b/aiter/ops/flydsl/kernels/preshuffle_gemm.py @@ -1,5 +1,7 @@ """Preshuffle GEMM kernel using the @flyc.kernel API.""" +import functools + import flydsl.compiler as flyc import flydsl.expr as fx from flydsl.compiler.kernel_function import CompilationContext @@ -118,10 +120,10 @@ def _get_preload(tile_m, tile_n, tile_k): ) +@functools.lru_cache(maxsize=1024) def compile_preshuffle_gemm_a8( *, - M: int = 0, - N: int = 0, + N: int, K: int, tile_m: int, tile_n: int, @@ -140,8 +142,8 @@ def compile_preshuffle_gemm_a8( Returns a JitFunction that auto-compiles and executes when called. Signature: launch_fn(arg_c, arg_a, arg_b, arg_scale_a, arg_scale_b, M, N, stream) - Compile-time constants: K, tile_m/n/k, in_dtype, out_dtype (determine loop structure). - Runtime parameters: M, N (passed as i32 kernel args). + Compile-time constants: N, K, tile_m/n/k, in_dtype, out_dtype (determine loop structure). + Runtime parameters: M (passed as i32 kernel arg). Args: out_dtype: Output element type, "fp16" or "bf16" (default: "fp16"). @@ -177,6 +179,19 @@ def compile_preshuffle_gemm_a8( a_elem_vec_pack = 2 if is_fp4 else 1 b_elem_vec_pack = 2 if is_fp4 else 1 + KERNEL_NAME = ( + f"preshuffle_gemm_{in_dtype}_{out_dtype}" + f"_t{tile_m}x{tile_n}x{tile_k}" + f"_lds{lds_stage}" + f"_pl{dsrd_preload}x{dvmem_preload}" + ) + if use_cshuffle_epilog: + KERNEL_NAME += "_csh" + if use_async_copy: + KERNEL_NAME += "_async" + if waves_per_eu is not None: + KERNEL_NAME += f"_wpe{waves_per_eu}" + tile_k_bytes = int(tile_k) * int(elem_bytes) if (tile_k_bytes % 64) != 0: @@ -1638,6 +1653,7 @@ def launch_gemm( gx = (i32_m + (tile_m - 1)) // tile_m gy = i32_n // tile_n + kernel_gemm._func.__name__ = KERNEL_NAME launcher = kernel_gemm( arg_c, arg_a, arg_b, arg_scale_a, arg_scale_b, i32_m, i32_n ) @@ -1660,8 +1676,7 @@ def launch_gemm( def compile_preshuffle_gemm_w4( *, - M: int = 0, - N: int = 0, + N: int, K: int, tile_m: int, tile_n: int, @@ -1684,7 +1699,6 @@ def compile_preshuffle_gemm_w4( if str(get_hip_arch()) != "gfx950": raise RuntimeError(f"FP4 GEMM requires gfx950, got {get_hip_arch()}") inner = compile_preshuffle_gemm_a8( - M=M, N=N, K=K, tile_m=tile_m,