Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions aiter/ops/flydsl/gemm_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,14 @@

from ..shuffle import shuffle_weight
from .kernels.splitk_hgemm import compile_hgemm_kernel
from .utils import is_flydsl_available
from .utils import get_shared_memory_per_block, is_flydsl_available

__all__ = [
"flydsl_hgemm",
]

SPLIT_K_COUNTER_MAX_LEN = 128
SPLIT_K_SIGNAL_STATE_COUNT = 3
MAX_LDS_BYTES = 163840
FIXED_STAGE = 2
FIXED_C_TO_LDS = False
KERNEL_ASYNC_COPY = get_rocm_arch() != "gfx942"
Expand Down Expand Up @@ -333,10 +332,11 @@ def _validate_hgemm_tiling(
stages=stages,
b_to_lds=b_to_lds,
)
if lds_bytes > MAX_LDS_BYTES:
lds_limit = get_shared_memory_per_block(fallback_gfx=get_gfx())
if lds_bytes > lds_limit:
Comment thread
yzhou103 marked this conversation as resolved.
raise ValueError(
"Invalid tile combination: estimated LDS usage "
f"{lds_bytes} exceeds the hardware limit {MAX_LDS_BYTES}"
f"{lds_bytes} exceeds the hardware limit {lds_limit}"
)


Expand Down
26 changes: 9 additions & 17 deletions aiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
import math
import os

from aiter.ops.flydsl.utils import (
addressable_lds_bytes_for_gfx as _addressable_lds_bytes_for_gfx,
get_shared_memory_per_block,
)


def get_gfx():
"""Detect GPU arch: honour GPU_ARCHS env, fall back to chip_info, default gfx942."""
Expand Down Expand Up @@ -165,26 +170,13 @@ def kernel_instance_estimated_lds_bytes(ki: kernelInstance) -> int:
)


# Per-kernel LDS cap for tune filtering (must match LLVM AMDGPU
# getAddressableLocalMemorySize for the compile target).
# When arch cannot be parsed (no GPU, bad string), stay conservative for CDNA.
_FALLBACK_MAX_LDS_BYTES = 65536


def addressable_lds_bytes_for_gfx(gfx: str) -> int:
g = (gfx or "").strip().lower().split(":")[0]
if not g.startswith("gfx"):
return _FALLBACK_MAX_LDS_BYTES
if g.startswith("gfx950"):
return 163840
if g.startswith("gfx7") or g.startswith("gfx8"):
return 32768
return 65536
return _addressable_lds_bytes_for_gfx(gfx)


def max_lds_bytes_for_tune() -> int:
"""Addressable LDS limit for current target (from ``get_gfx()``)."""
return addressable_lds_bytes_for_gfx(get_gfx())
"""Addressable LDS limit for current target."""
return get_shared_memory_per_block(fallback_gfx=get_gfx())


# fmt: off
Expand Down Expand Up @@ -277,7 +269,7 @@ def _estimate_max_wpe(tile_m: int, tile_n: int, total_vgpr: int = 512) -> int:

Preshuffle GEMM always uses 16x16 MFMA (4 VGPRs per thread per block).
Per-thread accum VGPRs = round_up(tile_m, 16) * round_up(tile_n, 16) / 256.
Estimated total accum * 1.5 (pipeline overhead for A/B buffers).
Estimated total ~= accum * 1.5 (pipeline overhead for A/B buffers).
Returns the max waves_per_eu that the register file can support.
"""
padded_m = math.ceil(tile_m / _MFMA_M) * _MFMA_M
Expand Down
42 changes: 29 additions & 13 deletions aiter/ops/flydsl/kernels/splitk_hgemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr

from .tensor_shim import GTensor, STensor, _to_raw, get_dtype_in_kernel
from ..utils import get_shared_memory_per_block

SPLIT_K_COUNTER_MAX_LEN = 128
SPLIT_K_SIGNAL_STATE_COUNT = 3
Expand Down Expand Up @@ -178,6 +179,15 @@ def compile_hgemm_kernel(
assert BLOCK_MN_SIZE % BLOCK_VECS == 0
BLOCK_K_BYTES = BLOCK_K * DTYPE_BYTES

KERNEL_NAME = f"hgemm_{dtype}_{BLOCK_M}x{BLOCK_N}x{BLOCK_K}_S{STAGES}TN"
KERNEL_NAME += "_NA" if not ASYNC_COPY else "_AS"
if B_PRE_SHUFFLE:
KERNEL_NAME += "_BP"
if IS_SPLIT_K:
KERNEL_NAME += f"_SPK{SPLIT_K}"
if B_TO_LDS:
KERNEL_NAME += "_BS"

allocator = SmemAllocator(None, arch=GPU_ARCH, global_sym_name="smem")
smem_a_offset = allocator._align(allocator.ptr, 16)
AS_BYTES = STAGES * BLOCK_M * BLOCK_K * DTYPE_BYTES
Expand All @@ -188,22 +198,20 @@ def compile_hgemm_kernel(
smem_b_offset = allocator._align(allocator.ptr, 16)
allocator.ptr = smem_b_offset + STAGES * BLOCK_N * BLOCK_K * DTYPE_BYTES
SMEM_USE += STAGES * BLOCK_N * BLOCK_K * DTYPE_BYTES
assert SMEM_USE <= 163840
smem_limit = get_shared_memory_per_block(fallback_gfx=GPU_ARCH)
if SMEM_USE > smem_limit:
raise RuntimeError(
f"{KERNEL_NAME} requires {SMEM_USE} bytes LDS, "
f"but device limit is {smem_limit} bytes "
f"(arch={GPU_ARCH}, TILE_M={TILE_M}, TILE_N={TILE_N}, TILE_K={TILE_K}, "
f"SPLIT_K={SPLIT_K}, B_TO_LDS={B_TO_LDS})",
)
LDG_ASYNC_VEC_SIZE = DMA_BYTES // DTYPE_BYTES
LDG_A_X_THREADS_AS = BLOCK_K // LDG_ASYNC_VEC_SIZE
LDG_REG_A_COUNT_AS = BLOCK_MK_SIZE // LDG_ASYNC_VEC_SIZE // BLOCK_THREADS
LDG_B_X_THREADS_AS = BLOCK_K // LDG_ASYNC_VEC_SIZE
LDG_REG_B_COUNT_AS = BLOCK_NK_SIZE // LDG_ASYNC_VEC_SIZE // BLOCK_THREADS

KERNEL_NAME = f"hgemm_{dtype}_{BLOCK_M}x{BLOCK_N}x{BLOCK_K}_S{STAGES}TN"
KERNEL_NAME += "_NA" if not ASYNC_COPY else "_AS"
if B_PRE_SHUFFLE:
KERNEL_NAME += "_BP"
if IS_SPLIT_K:
KERNEL_NAME += f"_SPK{SPLIT_K}"
if B_TO_LDS:
KERNEL_NAME += "_BS"

@flyc.kernel
def hgemm_kernel(
C: fx.Tensor,
Expand Down Expand Up @@ -925,9 +933,17 @@ def _launch(*args, **kwargs):
def _compile(C, A, B, m, COUNTER, signal_state, stream):
with CompilationContext.compile_hints(_compile_hints):
if _compile_cache.get(m, None) is None:
_compile_cache[m] = flyc.compile(
launch_hgemm_kernel, C, A, B, m, COUNTER, signal_state, stream
)
try:
_compile_cache[m] = flyc.compile(
launch_hgemm_kernel, C, A, B, m, COUNTER, signal_state, stream
)
except Exception as e:
raise RuntimeError(
f"{KERNEL_NAME} failed "
f"(arch={GPU_ARCH}, n={n}, k={k}, TILE_M={TILE_M}, TILE_N={TILE_N}, "
f"TILE_K={TILE_K}, SPLIT_K={SPLIT_K}, B_TO_LDS={B_TO_LDS}, "
f"SMEM_USE={SMEM_USE}, SMEM_LIMIT={smem_limit}): {e}",
) from e
return _compile_cache[m]

_launch.compile = _compile
Expand Down
60 changes: 60 additions & 0 deletions aiter/ops/flydsl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,66 @@
"""General utilities shared across all FlyDSL kernel families."""

import importlib.util
from functools import lru_cache

import torch

_FALLBACK_MAX_LDS_BYTES = 65536


def addressable_lds_bytes_for_gfx(gfx: str) -> int:
    """Return the addressable LDS (shared-memory) size in bytes for an AMD
    ``gfx`` arch string such as ``"gfx942"`` or ``"gfx90a:sramecc+"``.

    Non-gfx / unparsable strings fall back to a conservative default.
    """
    # Normalize: trim whitespace, lowercase, and drop any ":feature" suffix.
    base = (gfx or "").strip().lower().split(":")[0]
    if not base.startswith("gfx"):
        return _FALLBACK_MAX_LDS_BYTES
    # gfx950 exposes an enlarged 160 KiB LDS; the old gfx7/gfx8 generations
    # only 32 KiB; every other gfx target gets 64 KiB.
    for prefix, lds_bytes in (("gfx950", 163840), ("gfx7", 32768), ("gfx8", 32768)):
        if base.startswith(prefix):
            return lds_bytes
    return 65536


@lru_cache(maxsize=1)
def _default_cuda_device_index():
try:
return int(torch.cuda.current_device())
except Exception:
return None


@lru_cache(maxsize=None)
def _get_shared_memory_per_block_cached(device_index: int, fallback_gfx: str) -> int:
    """Per-block shared-memory (LDS) limit in bytes for *device_index*.

    Prefers the value torch reports for the device; when that attribute is
    missing or zero, derives the limit from the device's ``gcnArchName``
    (or *fallback_gfx* when even the device query fails). Results are
    memoized per ``(device_index, fallback_gfx)`` pair.
    """
    try:
        props = torch.cuda.get_device_properties(device_index)
        reported = int(getattr(props, "shared_memory_per_block", 0) or 0)
        if reported > 0:
            return reported
        # Device reports nothing useful — fall back to an arch lookup.
        return addressable_lds_bytes_for_gfx(
            getattr(props, "gcnArchName", fallback_gfx)
        )
    except Exception:
        # Any failure querying the device: use the caller-supplied hint.
        return addressable_lds_bytes_for_gfx(fallback_gfx)


def get_shared_memory_per_block(device=None, fallback_gfx: str = "") -> int:
    """Return the per-block shared-memory/LDS limit (bytes) for *device*.

    *device* may be ``None`` (use the current device), a ``torch.device``,
    or anything convertible to an integer device index. When no CUDA
    device can be resolved, the limit is derived from *fallback_gfx*.
    """

    def _resolve_index(dev):
        # Collapse the flexible ``device`` argument onto a concrete CUDA
        # device index, or None when no CUDA device applies.
        if dev is None:
            return _default_cuda_device_index()
        if isinstance(dev, torch.device):
            if dev.type != "cuda":
                return None
            if dev.index is None:
                return _default_cuda_device_index()
            return int(dev.index)
        try:
            return int(dev)
        except Exception:
            return None

    index = _resolve_index(device)
    if index is None:
        # No resolvable device: pure arch-string lookup.
        return addressable_lds_bytes_for_gfx(fallback_gfx)
    return _get_shared_memory_per_block_cached(index, fallback_gfx)


def is_flydsl_available() -> bool:
Expand Down
40 changes: 24 additions & 16 deletions aiter/ops/gemm_op_a8w8.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,23 @@ def gemm_a8w8_bpreshuffle_cktile(
) -> Tensor: ...


def _parse_flydsl_kernel_name(kernel_name: str):
"""Parse tile config from flydsl kernelName, e.g.
'flydsl_bpreshuflle_128x64x256_F8_F8_B16_2x0x1x1_default'
-> (tile_m=128, tile_n=64, tile_k=256, lds_stage=2, cshuffle=0, async_copy=1, wpe=1)
Returns None on parse failure.
"""
import re

m = re.match(
r"flydsl_bpreshuflle_(\d+)x(\d+)x(\d+)_\w+_\w+_\w+_(\d+)x(\d+)x(\d+)x(\d+)",
kernel_name,
)
if m is None:
return None
return tuple(int(m.group(i)) for i in range(1, 8))


def gemm_a8w8_bpreshuffle_flydsl(
XQ: Tensor,
WQ: Tensor,
Expand All @@ -112,22 +129,12 @@ def gemm_a8w8_bpreshuffle_flydsl(
config: dict,
) -> Tensor:
from .flydsl.gemm_kernels import flydsl_preshuffle_gemm_a8
from .flydsl.gemm_tune.flydsl_gemm_a8w8_bpreshuffle_common import (
kernels_list as kernels_list_flydsl,
)

kernel_id = config.get("kernelId")
if kernel_id is not None and kernel_id in kernels_list_flydsl:
ki = kernels_list_flydsl[kernel_id]
tm, tn, tk = ki.tile_m, ki.tile_n, ki.tile_k
lds, csh, acp, wpe = (
ki.lds_stage,
ki.use_cshuffle_epilog,
ki.use_async_copy,
ki.waves_per_eu,
)
else:
kernel_name = config.get("kernelName", "")
parsed = _parse_flydsl_kernel_name(str(kernel_name))
if parsed is None:
return gemm_a8w8_bpreshuffle_ck(XQ, WQ, x_scale, w_scale, Out)
tm, tn, tk, lds, csh, acp, wpe = parsed

flydsl_preshuffle_gemm_a8(
XQ.contiguous(),
Expand Down Expand Up @@ -488,8 +495,9 @@ def gemm_a8w8_ASM(
)
is not None
):
assert bias is not None, "Use asm gemm must give bias, please give a \
bias=torch.zeros(n,dtype=dtypes.fp32,device='cuda')"
assert (
bias is not None
), "Use asm gemm must give bias, please give a bias=torch.zeros(n,dtype=dtypes.fp32,device='cuda')"
splitK = asm_config["splitK"]
kernelName = asm_config["kernelName"]
Y = torch.empty(m, n, dtype=dtype, device=XQ.device)
Expand Down
10 changes: 5 additions & 5 deletions aiter/utility/base_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,13 +443,13 @@ def post_process(self, rets, args, topk=-1, fast_mode=False):
f"error: no valid candidate found for {info_key}, please check the result or errRatio in all result file running with --profile_file"
)

if len(filtered_time) < topk:
topk = len(filtered_time)
print(f"choose {topk} kernels")
self.topk = topk
effective_topk = min(topk, len(filtered_time))
if effective_topk < topk:
print(f"choose {effective_topk} kernels")
self.topk = effective_topk
best_config = [
((info_key, *info_ex), us, max_err_ratio)
for info_ex, us, max_err_ratio in filtered_time[0:topk]
for info_ex, us, max_err_ratio in filtered_time[0:effective_topk]
]
if not best_config:
logger.info(f"No kernel can be used for {info_key}")
Expand Down
19 changes: 14 additions & 5 deletions aiter/utility/mp_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,14 @@
from aiter import logger


def _is_mapping_error(exc: BaseException) -> bool:
return isinstance(exc, KeyError)


def _is_accelerator_error(exc: BaseException) -> bool:
return type(exc).__name__ == "AcceleratorError"


def worker(
gpu_id,
info,
Expand Down Expand Up @@ -36,7 +44,7 @@ def worker(
res, us = run_perftest(func, *args, **kwargs)
us = round(us, 4)

except RuntimeError as e:
except (RuntimeError, ValueError) as e:
print(f"run gpu func warning: info:{info}\t {e}", flush=True)
us = -1 # not support or error
max_err_ratio = 1.0
Expand All @@ -50,6 +58,8 @@ def worker(
if us == 0:
print(f"Warning: try run {max_retries} times, but still get 0!")
torch.cuda.synchronize()
if us == -1 or res is None:
return info, us, round(max_err_ratio, 4)
if ref is not None:
if isinstance(ref, torch.Tensor):
ref = [ref]
Expand Down Expand Up @@ -448,14 +458,13 @@ def add_dummy_result(k, results_list):
except Exception as e:
# Check if it's a process crash (segfault, memory fault, etc.)
error_type = type(e).__name__

# Special handling for KeyError (PID mapping issue)
is_mapping_error = error_type == "KeyError"
is_mapping_error = _is_mapping_error(e)
is_accelerator_error = _is_accelerator_error(e)
# not restart as this is not root use
if is_mapping_error:
error_msg = f"[Mapping Error] Task {k} - Process PID not in GPU map: {error_type} - {e}"
dummy_failed_tasks.append((k, "mapping error"))
elif error_type == "AcceleratorError":
elif is_accelerator_error:
# GPU fault (e.g. illegal memory access): worker returns exception instead of
# hanging. Unlike hang->timeout, the faulting worker may stay alive and accept
# more tasks on the same bad GPU. Break immediately to trigger restart and
Expand Down
Loading