Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions aiter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def getLogger():


logger = getLogger()
AITER_AOT_IMPORT = os.getenv("AITER_AOT_IMPORT", "0") == "1"

# Use bundled pre-compiled FlyDSL cache unless the user overrides via env var.
_flydsl_cache = os.path.join(os.path.dirname(__file__), "jit", "flydsl_cache")
Expand All @@ -65,6 +66,8 @@ def getLogger():

if sys.platform == "win32":
logger.info("Windows: CK and HIP ops are not available. Triton ops only.")
elif AITER_AOT_IMPORT:
from .jit import core as core # noqa: E402
else:
try:
from .jit import core as core # noqa: E402
Expand Down
93 changes: 93 additions & 0 deletions aiter/aot/flydsl/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,17 @@
import os
from typing import Any, Callable, Iterator

# Known compute-unit counts and the GPU architecture each one corresponds to.
# NOTE(review): 80 presumably corresponds to a cut-down gfx942 part — confirm
# against the hardware matrix if new CU counts are added.
_CU_NUM_TO_ARCH = {
    80: "gfx942",
    304: "gfx942",
    256: "gfx950",
}


def cu_num_to_arch(cu_num: int, default: str = "gfx950") -> str:
    """Return the GPU architecture string for a compute-unit count.

    Unrecognized CU counts fall back to *default*.
    """
    try:
        return _CU_NUM_TO_ARCH[cu_num]
    except KeyError:
        return default


def job_identity(job: dict[str, Any]) -> tuple:
    """Return a hashable, order-independent identity tuple for a job dict."""
    ordered_items = sorted(job.items())
    return tuple(ordered_items)
Expand Down Expand Up @@ -51,3 +62,85 @@ def compile_only_env() -> Iterator[None]:
os.environ.pop("COMPILE_ONLY", None)
else:
os.environ["COMPILE_ONLY"] = prev


@contextmanager
def override_env(var_name: str, value: str | None) -> Iterator[None]:
prev = os.environ.get(var_name)
if value is None:
os.environ.pop(var_name, None)
else:
os.environ[var_name] = value
try:
yield
finally:
if prev is None:
os.environ.pop(var_name, None)
else:
os.environ[var_name] = prev


def run_aot_worker(kind):
    """Worker for ProcessPoolExecutor — runs in a child process.

    Lazily imports the per-kind compile helpers (the import must happen in
    the child process), compiles every job collected from the default CSVs,
    and returns a ``(label, ok_count, fail_count)`` summary tuple.
    """
    # Import inside the worker so each child process resolves its own helpers.
    if kind == "moe":
        from .moe import DEFAULT_CSVS, compile_one_config, parse_csv
    else:
        from .gemm import DEFAULT_CSVS, compile_one_config, parse_csv

    label = f"FlyDSL {kind.upper()} AOT"
    jobs = collect_aot_jobs(DEFAULT_CSVS, parse_csv)
    if not jobs:
        return label, 0, 0

    cache_dir = os.environ.get("FLYDSL_RUNTIME_CACHE_DIR", "~/.flydsl/cache")
    print(f"[aiter] {label}: {len(jobs)} kernels to compile (cache: {cache_dir})")

    # A job counts as "ok" when compile_one_config recorded a compile time.
    outcomes = [compile_one_config(**job) for job in jobs]
    ok = len([r for r in outcomes if r["compile_time"] is not None])
    fail = len(outcomes) - ok
    print(f"[aiter] {label}: compiled {ok} ok, {fail} failed")
    return label, ok, fail


def start_aot(cache_dir: str):
    """Start FlyDSL AOT compilation in background processes.

    Returns (pool, futures_dict) — caller must call ``wait_aot``
    to collect results and raise on failure.
    """
    from concurrent.futures import ProcessPoolExecutor

    # Child workers pick the cache directory up from the environment.
    os.makedirs(cache_dir, exist_ok=True)
    os.environ["FLYDSL_RUNTIME_CACHE_DIR"] = cache_dir

    executor = ProcessPoolExecutor(max_workers=2)
    kinds_and_labels = (("moe", "MoE"), ("gemm", "GEMM"))
    futures = {
        executor.submit(run_aot_worker, kind): label
        for kind, label in kinds_and_labels
    }
    return executor, futures


def wait_aot(pool, futures):
    """Wait for FlyDSL AOT workers and raise on any failure.

    ``futures`` maps each future to a human-readable label. Every future is
    awaited; compile failures and worker crashes are all collected before a
    single AssertionError summarizing them is raised. The pool is shut down
    unconditionally, even on error.
    """
    try:
        problems = []
        for fut, worker_label in futures.items():
            try:
                label, ok, fail = fut.result()
            except Exception as worker_err:
                problems.append(
                    f"FlyDSL {worker_label} AOT worker crashed: {worker_err}"
                )
            else:
                if fail > 0:
                    problems.append(f"{label}: {fail} compile failure(s)")
        if problems:
            raise AssertionError(
                "[aiter] FlyDSL AOT failures: " + "; ".join(problems)
            )
    finally:
        pool.shutdown(wait=False)
65 changes: 35 additions & 30 deletions aiter/aot/flydsl/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,16 @@
import time
from typing import Dict, Optional

from aiter.aot.flydsl.common import collect_aot_jobs, compile_only_env, job_identity
import flydsl.expr as fx

from aiter.aot.flydsl.common import (
collect_aot_jobs,
compile_only_env,
cu_num_to_arch,
job_identity,
override_env,
)
from aiter.jit.core import AITER_CONFIGS
from aiter.jit.utils.chip_info import get_gfx
from aiter.ops.flydsl.gemm_kernels import get_flydsl_splitk_hgemm_kernel_params
from aiter.ops.flydsl.kernels.hgemm_dispatch import compile_flydsl_hgemm_kernel
from aiter.ops.flydsl.kernels.preshuffle_gemm import compile_preshuffle_gemm_a8
Expand All @@ -54,6 +61,7 @@
AITER_CONFIGS.AITER_CONFIG_BF16_BATCHED_GEMM_FILE,
AITER_CONFIGS.AITER_CONFIG_GEMM_BF16_FILE,
]
GEMM_AOT_ARCH_DEFAULT = "gfx950"

_PRESHUFFLE_RE = re.compile(
r"^flydsl_bpreshuflle_"
Expand Down Expand Up @@ -129,6 +137,7 @@ def parse_csv(csv_path: str):
m = int(row["M"])
n = int(row["N"])
k = int(row["K"])
cu_num = int(row.get("cu_num", "0"))

if kernel_name.startswith("flydsl_bpreshuflle_"):
params = _parse_preshuffle_kernel_name(kernel_name)
Expand All @@ -151,6 +160,7 @@ def parse_csv(csv_path: str):
"m": m,
"n": n,
"k": k,
"cu_num": cu_num,
"has_bias": _parse_bool(row.get("bias")),
**params,
}
Expand Down Expand Up @@ -178,14 +188,8 @@ def _torch_dtype_for_kernel(dtype_name: str):


def _compile_executable_to_cache(exe, *args) -> None:
compile_fn = getattr(exe, "compile", None)
if compile_fn is None:
import flydsl.compiler as flyc

compile_fn = flyc.compile
args = (exe, *args)
with compile_only_env():
compile_fn(*args)
exe(*args)


def _compile_hgemm_to_cache(
Expand Down Expand Up @@ -218,16 +222,10 @@ def _compile_hgemm_to_cache(

import torch

dev = torch.device("cuda")
has_cuda = torch.cuda.is_available() and torch.cuda.device_count() > 0
dev = torch.device("cuda") if has_cuda else torch.device("cpu")
torch_dtype = _torch_dtype_for_kernel(dtype)

current_gfx = get_gfx()
if target_gfx != current_gfx:
print(
f" [WARN] Kernel targets {target_gfx} but current target is {current_gfx}; "
"compiling with current target parameters"
)

out = torch.empty((m, n), device=dev, dtype=torch_dtype)
a = torch.empty((m, k), device=dev, dtype=torch_dtype)
b = torch.empty((n, k), device=dev, dtype=torch_dtype)
Expand All @@ -237,7 +235,7 @@ def _compile_hgemm_to_cache(
device=dev,
dtype=torch.int32,
)
stream = torch.cuda.current_stream(device=dev)
stream = fx.Stream(torch.cuda.current_stream(device=dev) if has_cuda else 0)

exe = compile_flydsl_hgemm_kernel(
dtype,
Expand Down Expand Up @@ -285,7 +283,8 @@ def _compile_preshuffle_to_cache(

import torch

dev = torch.device("cuda")
has_cuda = torch.cuda.is_available() and torch.cuda.device_count() > 0
dev = torch.device("cuda") if has_cuda else torch.device("cpu")
out_torch_dtype = _torch_dtype_for_kernel(out_dtype)

# FlyDSL preshuffle kernels consume raw quantized bytes for fp8/int8 paths.
Expand All @@ -294,7 +293,7 @@ def _compile_preshuffle_to_cache(
out = torch.empty((m * n,), device=dev, dtype=out_torch_dtype)
scale_a = torch.empty((max(m, 1),), device=dev, dtype=torch.float32)
scale_b = torch.empty((max(n, 1),), device=dev, dtype=torch.float32)
stream = torch.cuda.current_stream(device=dev)
stream = fx.Stream(torch.cuda.current_stream(device=dev) if has_cuda else 0)

exe = compile_preshuffle_gemm_a8(
N=n,
Expand All @@ -313,31 +312,36 @@ def _compile_preshuffle_to_cache(


def compile_one_config(
kernel_name: str, kind: str, m: int, n: int, k: int, **kwargs
kernel_name: str, kind: str, m: int, n: int, k: int, cu_num: int = 0, **kwargs
) -> dict:
"""Compile one GEMM kernel configuration and save it to cache."""
aot_arch = cu_num_to_arch(cu_num, default=GEMM_AOT_ARCH_DEFAULT)
shape_str = f"{kernel_name} M={m} N={n} K={k}"
result = {
"kernel_name": kernel_name,
"kind": kind,
"shape": shape_str,
"compile_time": None,
"compile_arch": aot_arch,
}

t0 = time.time()
try:
if kind == "hgemm":
_compile_hgemm_to_cache(m=m, n=n, k=k, **kwargs)
elif kind == "preshuffle":
_compile_preshuffle_to_cache(m=m, n=n, k=k, **kwargs)
else:
raise ValueError(f"Unknown GEMM AOT kind: {kind}")
with override_env("ARCH", aot_arch), override_env("FLYDSL_GPU_ARCH", aot_arch):
if kind == "hgemm":
hgemm_kwargs = dict(kwargs)
hgemm_kwargs["target_gfx"] = aot_arch
_compile_hgemm_to_cache(m=m, n=n, k=k, **hgemm_kwargs)
elif kind == "preshuffle":
_compile_preshuffle_to_cache(m=m, n=n, k=k, **kwargs)
else:
raise ValueError(f"Unknown GEMM AOT kind: {kind}")

elapsed = time.time() - t0
result["compile_time"] = elapsed
print(f" [OK] compile {elapsed:6.1f}s {shape_str}")
print(f" [OK] compile {elapsed:6.1f}s {shape_str} arch={aot_arch}")
except Exception as e:
print(f" [FAIL] compile {shape_str}: {e}")
print(f" [FAIL] compile {shape_str} arch={aot_arch}: {e}")

return result

Expand Down Expand Up @@ -365,7 +369,7 @@ def main():
cache_dir = os.path.expanduser(
os.environ.get("FLYDSL_RUNTIME_CACHE_DIR", "~/.flydsl/cache")
)
arch = os.environ.get("ARCH") or os.environ.get("GPU_ARCHS") or get_gfx()
arch = os.environ.get("ARCH") or os.environ.get("GPU_ARCHS") or "(auto-detect)"

all_jobs = collect_aot_jobs(csv_paths, parse_csv)

Expand All @@ -380,6 +384,7 @@ def main():
print(f" HGEMM jobs: {len(hgemm_jobs)}")
print(f" Preshuffle jobs: {len(preshuffle_jobs)}")
print(f" Total jobs: {len(all_jobs)}")
print(" Compile arch: (from cu_num)")
print(f" Cache dir: {cache_dir}")
print(f" Target arch: {arch}")
print("=" * 72)
Expand Down
Loading
Loading