diff --git a/aiter/__init__.py b/aiter/__init__.py index e1a7cf901c..b6dcb34a2f 100644 --- a/aiter/__init__.py +++ b/aiter/__init__.py @@ -57,6 +57,7 @@ def getLogger(): logger = getLogger() +AITER_AOT_IMPORT = os.getenv("AITER_AOT_IMPORT", "0") == "1" # Use bundled pre-compiled FlyDSL cache unless the user overrides via env var. _flydsl_cache = os.path.join(os.path.dirname(__file__), "jit", "flydsl_cache") @@ -65,6 +66,8 @@ def getLogger(): if sys.platform == "win32": logger.info("Windows: CK and HIP ops are not available. Triton ops only.") +elif AITER_AOT_IMPORT: + from .jit import core as core # noqa: E402 else: try: from .jit import core as core # noqa: E402 diff --git a/aiter/aot/flydsl/common.py b/aiter/aot/flydsl/common.py index d9e34f9190..c6ddfc8673 100644 --- a/aiter/aot/flydsl/common.py +++ b/aiter/aot/flydsl/common.py @@ -9,6 +9,17 @@ import os from typing import Any, Callable, Iterator +_CU_NUM_TO_ARCH = { + 80: "gfx942", + 304: "gfx942", + 256: "gfx950", +} + + +def cu_num_to_arch(cu_num: int, default: str = "gfx950") -> str: + """Map compute-unit count to GPU architecture string.""" + return _CU_NUM_TO_ARCH.get(cu_num, default) + def job_identity(job: dict[str, Any]) -> tuple: return tuple(sorted(job.items())) @@ -51,3 +62,85 @@ def compile_only_env() -> Iterator[None]: os.environ.pop("COMPILE_ONLY", None) else: os.environ["COMPILE_ONLY"] = prev + + +@contextmanager +def override_env(var_name: str, value: str | None) -> Iterator[None]: + prev = os.environ.get(var_name) + if value is None: + os.environ.pop(var_name, None) + else: + os.environ[var_name] = value + try: + yield + finally: + if prev is None: + os.environ.pop(var_name, None) + else: + os.environ[var_name] = prev + + +def run_aot_worker(kind): + """Worker for ProcessPoolExecutor — runs in a child process.""" + if kind == "moe": + from .moe import ( + DEFAULT_CSVS, + compile_one_config, + parse_csv, + ) + else: + from .gemm import ( + DEFAULT_CSVS, + compile_one_config, + parse_csv, + ) + + label = f"FlyDSL {kind.upper()} AOT" + jobs = collect_aot_jobs(DEFAULT_CSVS, parse_csv) + if not jobs: + return label, 0, 0 + cache_dir = os.environ.get("FLYDSL_RUNTIME_CACHE_DIR", "~/.flydsl/cache") + print(f"[aiter] {label}: {len(jobs)} kernels to compile (cache: {cache_dir})") + results = [compile_one_config(**job) for job in jobs] + ok = sum(1 for r in results if r["compile_time"] is not None) + fail = len(results) - ok + print(f"[aiter] {label}: compiled {ok} ok, {fail} failed") + return label, ok, fail + + +def start_aot(cache_dir: str): + """Start FlyDSL AOT compilation in background processes. + + Returns (pool, futures_dict) — caller must call ``wait_aot`` + to collect results and raise on failure. + """ + from concurrent.futures import ProcessPoolExecutor + + os.makedirs(cache_dir, exist_ok=True) + os.environ["FLYDSL_RUNTIME_CACHE_DIR"] = cache_dir + + pool = ProcessPoolExecutor(max_workers=2) + futures = { + pool.submit(run_aot_worker, "moe"): "MoE", + pool.submit(run_aot_worker, "gemm"): "GEMM", + } + return pool, futures + + +def wait_aot(pool, futures): + """Wait for FlyDSL AOT workers and raise on any failure.""" + try: + errors = [] + for future in futures: + try: + label, ok, fail = future.result() + if fail > 0: + errors.append(f"{label}: {fail} compile failure(s)") + except Exception as worker_err: + errors.append( + f"FlyDSL {futures[future]} AOT worker crashed: {worker_err}" + ) + if errors: + raise AssertionError("[aiter] FlyDSL AOT failures: " + "; ".join(errors)) + finally: + pool.shutdown(wait=False) diff --git a/aiter/aot/flydsl/gemm.py b/aiter/aot/flydsl/gemm.py index 6163d6c77a..66b42ff274 100644 --- a/aiter/aot/flydsl/gemm.py +++ b/aiter/aot/flydsl/gemm.py @@ -36,9 +36,16 @@ import time from typing import Dict, Optional -from aiter.aot.flydsl.common import collect_aot_jobs, compile_only_env, job_identity +import flydsl.expr as fx + +from aiter.aot.flydsl.common import ( + collect_aot_jobs, + compile_only_env, + cu_num_to_arch, + job_identity, + override_env, +) from aiter.jit.core import AITER_CONFIGS -from aiter.jit.utils.chip_info import get_gfx from aiter.ops.flydsl.gemm_kernels import get_flydsl_splitk_hgemm_kernel_params from aiter.ops.flydsl.kernels.hgemm_dispatch import compile_flydsl_hgemm_kernel from aiter.ops.flydsl.kernels.preshuffle_gemm import compile_preshuffle_gemm_a8 @@ -54,6 +61,7 @@ AITER_CONFIGS.AITER_CONFIG_BF16_BATCHED_GEMM_FILE, AITER_CONFIGS.AITER_CONFIG_GEMM_BF16_FILE, ] +GEMM_AOT_ARCH_DEFAULT = "gfx950" _PRESHUFFLE_RE = re.compile( r"^flydsl_bpreshuflle_" @@ -129,6 +137,7 @@ def parse_csv(csv_path: str): m = int(row["M"]) n = int(row["N"]) k = int(row["K"]) + cu_num = int(row.get("cu_num", "0")) if kernel_name.startswith("flydsl_bpreshuflle_"): params = _parse_preshuffle_kernel_name(kernel_name) @@ -151,6 +160,7 @@ def parse_csv(csv_path: str): "m": m, "n": n, "k": k, + "cu_num": cu_num, "has_bias": _parse_bool(row.get("bias")), **params, } @@ -178,14 +188,8 @@ def _torch_dtype_for_kernel(dtype_name: str): def _compile_executable_to_cache(exe, *args) -> None: - compile_fn = getattr(exe, "compile", None) - if compile_fn is None: - import flydsl.compiler as flyc - - compile_fn = flyc.compile - args = (exe, *args) with compile_only_env(): - compile_fn(*args) + exe(*args) def _compile_hgemm_to_cache( @@ -218,16 +222,10 @@ def _compile_hgemm_to_cache( import torch - dev = torch.device("cuda") + has_cuda = torch.cuda.is_available() and torch.cuda.device_count() > 0 + dev = torch.device("cuda") if has_cuda else torch.device("cpu") torch_dtype = _torch_dtype_for_kernel(dtype) - current_gfx = get_gfx() - if target_gfx != current_gfx: - print( - f" [WARN] Kernel targets {target_gfx} but current target is {current_gfx}; " - "compiling with current target parameters" - ) - out = torch.empty((m, n), device=dev, dtype=torch_dtype) a = torch.empty((m, k), device=dev, dtype=torch_dtype) b = torch.empty((n, k), device=dev, dtype=torch_dtype) @@ -237,7 +235,7 @@ def _compile_hgemm_to_cache( device=dev, dtype=torch.int32, ) - stream = torch.cuda.current_stream(device=dev) + stream = fx.Stream(torch.cuda.current_stream(device=dev) if has_cuda else 0) exe = compile_flydsl_hgemm_kernel( dtype, @@ -285,7 +283,8 @@ def _compile_preshuffle_to_cache( import torch - dev = torch.device("cuda") + has_cuda = torch.cuda.is_available() and torch.cuda.device_count() > 0 + dev = torch.device("cuda") if has_cuda else torch.device("cpu") out_torch_dtype = _torch_dtype_for_kernel(out_dtype) # FlyDSL preshuffle kernels consume raw quantized bytes for fp8/int8 paths. @@ -294,7 +293,7 @@ def _compile_preshuffle_to_cache( out = torch.empty((m * n,), device=dev, dtype=out_torch_dtype) scale_a = torch.empty((max(m, 1),), device=dev, dtype=torch.float32) scale_b = torch.empty((max(n, 1),), device=dev, dtype=torch.float32) - stream = torch.cuda.current_stream(device=dev) + stream = fx.Stream(torch.cuda.current_stream(device=dev) if has_cuda else 0) exe = compile_preshuffle_gemm_a8( N=n, @@ -313,31 +312,36 @@ def _compile_preshuffle_to_cache( def compile_one_config( - kernel_name: str, kind: str, m: int, n: int, k: int, **kwargs + kernel_name: str, kind: str, m: int, n: int, k: int, cu_num: int = 0, **kwargs ) -> dict: """Compile one GEMM kernel configuration and save it to cache.""" + aot_arch = cu_num_to_arch(cu_num, default=GEMM_AOT_ARCH_DEFAULT) shape_str = f"{kernel_name} M={m} N={n} K={k}" result = { "kernel_name": kernel_name, "kind": kind, "shape": shape_str, "compile_time": None, + "compile_arch": aot_arch, } t0 = time.time() try: - if kind == "hgemm": - _compile_hgemm_to_cache(m=m, n=n, k=k, **kwargs) - elif kind == "preshuffle": - _compile_preshuffle_to_cache(m=m, n=n, k=k, **kwargs) - else: - raise ValueError(f"Unknown GEMM AOT kind: {kind}") + with override_env("ARCH", aot_arch), override_env("FLYDSL_GPU_ARCH", aot_arch): + if kind == "hgemm": + hgemm_kwargs = dict(kwargs) + hgemm_kwargs["target_gfx"] = aot_arch + _compile_hgemm_to_cache(m=m, n=n, k=k, **hgemm_kwargs) + elif kind == "preshuffle": + _compile_preshuffle_to_cache(m=m, n=n, k=k, **kwargs) + else: + raise ValueError(f"Unknown GEMM AOT kind: {kind}") elapsed = time.time() - t0 result["compile_time"] = elapsed - print(f" [OK] compile {elapsed:6.1f}s {shape_str}") + print(f" [OK] compile {elapsed:6.1f}s {shape_str} arch={aot_arch}") except Exception as e: - print(f" [FAIL] compile {shape_str}: {e}") + print(f" [FAIL] compile {shape_str} arch={aot_arch}: {e}") return result @@ -365,7 +369,7 @@ def main(): cache_dir = os.path.expanduser( os.environ.get("FLYDSL_RUNTIME_CACHE_DIR", "~/.flydsl/cache") ) - arch = os.environ.get("ARCH") or os.environ.get("GPU_ARCHS") or get_gfx() + arch = os.environ.get("ARCH") or os.environ.get("GPU_ARCHS") or "(auto-detect)" all_jobs = collect_aot_jobs(csv_paths, parse_csv) @@ -380,6 +384,7 @@ def main(): print(f" HGEMM jobs: {len(hgemm_jobs)}") print(f" Preshuffle jobs: {len(preshuffle_jobs)}") print(f" Total jobs: {len(all_jobs)}") + print(" Compile arch: (from cu_num)") print(f" Cache dir: {cache_dir}") print(f" Target arch: {arch}") print("=" * 72) diff --git a/aiter/aot/flydsl/moe.py b/aiter/aot/flydsl/moe.py index 589814909c..2ac3485b5d 100644 --- a/aiter/aot/flydsl/moe.py +++ b/aiter/aot/flydsl/moe.py @@ -28,7 +28,13 @@ import sys import time -from aiter.aot.flydsl.common import collect_aot_jobs, compile_only_env, job_identity +from aiter.aot.flydsl.common import ( + collect_aot_jobs, + compile_only_env, + cu_num_to_arch, + job_identity, + override_env, +) from aiter.jit.core import AITER_CONFIGS from aiter.ops.flydsl.moe_kernels import ( compile_flydsl_moe_stage1, @@ -45,6 +51,7 @@ DEFAULT_CSVS = [ AITER_CONFIGS.AITER_CONFIG_FMOE_FILE, ] +MOE_AOT_ARCH_DEFAULT = "gfx950" def parse_csv(csv_path: str): @@ -68,6 +75,7 @@ def parse_csv(csv_path: str): experts = int(row["expert"]) topk = int(row["topk"]) doweight_stage1 = bool(int(row.get("doweight_stage1", "0"))) + cu_num = int(row.get("cu_num", "0")) for col in ("kernelName1", "kernelName2"): name = row.get(col, "").strip() @@ -81,6 +89,7 @@ def parse_csv(csv_path: str): "experts": experts, "topk": topk, "doweight_stage1": doweight_stage1, + "cu_num": cu_num, } key = job_identity(job) if key in seen: @@ -113,11 +122,11 @@ def _precompile_to_cache( waves_per_eu: int = 3, k_batch: int = 1, b_nt: int = 2, - gate_only: bool = False, - fuse_fp4_quant: bool = False, + gate_mode: str = "separated", mode: str = "atomic", persist: bool = False, sort_block_m: int = 0, + cu_num: int = 0, **kwargs, ): """Trigger MLIR compilation with dummy tensors and COMPILE_ONLY=1. @@ -129,7 +138,8 @@ def _precompile_to_cache( """ import torch - dev = torch.device("cuda") + dev = torch.device("cpu") + _stream = 0 is_fp4 = b_dtype == "fp4" tokens = tile_m E = experts @@ -141,8 +151,15 @@ def _precompile_to_cache( num_valid_ids = torch.zeros(1, device=dev, dtype=torch.int32) sw = torch.zeros(tokens * topk, device=dev, dtype=torch.float32) - with compile_only_env(): + _cu_num_str = str(cu_num) if cu_num > 0 else None + with compile_only_env(), override_env("CU_NUM", _cu_num_str): + # Clear cached CU count so get_cu_num() re-reads the env var. + from aiter.jit.utils.chip_info import get_cu_num + + get_cu_num.cache_clear() + if stage == 1: + _is_splitk = k_batch > 1 n_in = inter_dim * 2 if is_fp4 else inter_dim k_in = model_dim @@ -174,6 +191,7 @@ def _precompile_to_cache( k_in, _grid_y, dev, + stream=_stream, ) else: out = torch.zeros( @@ -199,6 +217,7 @@ def _precompile_to_cache( n_in, k_in, _grid_y, + stream=_stream, ) exe = compile_flydsl_moe_stage1( @@ -216,13 +235,12 @@ def _precompile_to_cache( waves_per_eu=waves_per_eu, k_batch=k_batch, b_nt=b_nt, - gate_only=gate_only, - fuse_fp4_quant=fuse_fp4_quant and not _is_splitk, - fuse_sort_scale=fuse_fp4_quant and not _is_splitk, + gate_mode=gate_mode, ) _run_compiled(exe, args) elif stage == 2: + accumulate = mode != "reduce" _persist_m = -1 if persist else 4 n_in = model_dim @@ -253,6 +271,7 @@ def _precompile_to_cache( k_in, _grid_y, dev, + stream=_stream, ) else: out = torch.zeros(tokens * model_dim, device=dev, dtype=torch.bfloat16) @@ -274,6 +293,7 @@ def _precompile_to_cache( n_in, k_in, _grid_y, + stream=_stream, ) exe = compile_flydsl_moe_stage2( @@ -296,7 +316,13 @@ def _precompile_to_cache( def compile_one_config( - kernel_name: str, model_dim: int, inter_dim: int, experts: int, topk: int, **kwargs + kernel_name: str, + model_dim: int, + inter_dim: int, + experts: int, + topk: int, + cu_num: int = 0, + **kwargs, ) -> dict: """Compile one MoE kernel configuration and save to cache. @@ -305,27 +331,35 @@ def compile_one_config( Returns a dict with timing info. """ + aot_arch = cu_num_to_arch(cu_num, default=MOE_AOT_ARCH_DEFAULT) shape_str = ( f"{kernel_name} " f"model_dim={model_dim} inter_dim={inter_dim} " f"E={experts} topk={topk}" ) - result = {"kernel_name": kernel_name, "shape": shape_str, "compile_time": None} + result = { + "kernel_name": kernel_name, + "shape": shape_str, + "compile_time": None, + "compile_arch": aot_arch, + } t0 = time.time() try: - _precompile_to_cache( - model_dim=model_dim, - inter_dim=inter_dim, - experts=experts, - topk=topk, - **kwargs, - ) + with override_env("ARCH", aot_arch), override_env("FLYDSL_GPU_ARCH", aot_arch): + _precompile_to_cache( + model_dim=model_dim, + inter_dim=inter_dim, + experts=experts, + topk=topk, + cu_num=cu_num, + **kwargs, + ) elapsed = time.time() - t0 result["compile_time"] = elapsed - print(f" [OK] compile {elapsed:6.1f}s {shape_str}") + print(f" [OK] compile {elapsed:6.1f}s {shape_str} arch={aot_arch}") except Exception as e: - print(f" [FAIL] compile {shape_str}: {e}") + print(f" [FAIL] compile {shape_str} arch={aot_arch}: {e}") return result @@ -353,13 +387,12 @@ def main(): cache_dir = os.path.expanduser( os.environ.get("FLYDSL_RUNTIME_CACHE_DIR", "~/.flydsl/cache") ) - arch = os.environ.get("ARCH", "(auto-detect)") + arch = os.environ.get("ARCH") or os.environ.get("GPU_ARCHS") or "(auto-detect)" all_jobs = collect_aot_jobs(csv_paths, parse_csv) stage1_jobs = [j for j in all_jobs if j["stage"] == 1] stage2_jobs = [j for j in all_jobs if j["stage"] == 2] - print("=" * 72) print("FlyDSL MoE AOT Pre-compilation") print("=" * 72) @@ -368,6 +401,7 @@ def main(): print(f" Stage1 jobs: {len(stage1_jobs)}") print(f" Stage2 jobs: {len(stage2_jobs)}") print(f" Total jobs: {len(all_jobs)}") + print(" Compile arch: (from cu_num)") print(f" Cache dir: {cache_dir}") print(f" Target arch: {arch}") print("=" * 72) diff --git a/aiter/ops/flydsl/gemm_kernels.py b/aiter/ops/flydsl/gemm_kernels.py index 58139575f2..ab25f91cf0 100644 --- a/aiter/ops/flydsl/gemm_kernels.py +++ b/aiter/ops/flydsl/gemm_kernels.py @@ -13,7 +13,6 @@ from torch import Tensor from aiter import logger -from aiter.utility import dtypes from flydsl.runtime.device import get_rocm_arch from aiter.jit.utils.chip_info import get_gfx @@ -28,6 +27,13 @@ "flydsl_hgemm", ] + +def _get_dtypes(): + from aiter.utility import dtypes + + return dtypes + + SPLIT_K_COUNTER_MAX_LEN = 128 SPLIT_K_SIGNAL_STATE_COUNT = 3 FIXED_STAGE = 2 @@ -965,6 +971,7 @@ def flydsl_preshuffle_gemm_a8( compile_fn = _get_compile_fn() if compile_fn is None: raise RuntimeError("[FlyDSL] compile function not available") + dtypes = _get_dtypes() m, k = XQ.shape[0], XQ.shape[-1] n = WQ.shape[0] diff --git a/aiter/ops/flydsl/kernels/mfma_epilogues.py b/aiter/ops/flydsl/kernels/mfma_epilogues.py index e4c4d0f559..052bd4a6a8 100644 --- a/aiter/ops/flydsl/kernels/mfma_epilogues.py +++ b/aiter/ops/flydsl/kernels/mfma_epilogues.py @@ -183,7 +183,7 @@ def c_shuffle_epilog( def _write_row_split(mi: int, ii: int, row_in_tile, row): row_base_lds = row_in_tile * _half_n_idx - _if_g = scf.IfOp(_is_group_b) + _if_g = scf.IfOp(_is_group_b, has_else=True) with ir.InsertionPoint(_if_g.then_block): write_row_to_lds( mi=mi, @@ -266,7 +266,7 @@ def _do_store_row_split(): col_pair0_local = col_base_nr + (n_lane_s * c_evec) lds_idx = row_base_lds + col_pair0_local - _if_ld = scf.IfOp(_is_group_b, [vec_frag]) + _if_ld = scf.IfOp(_is_group_b, [vec_frag], has_else=True) with ir.InsertionPoint(_if_ld.then_block): fb = vector.load_op(vec_frag, lds_out_split, [lds_idx]) scf.YieldOp([fb]) diff --git a/aiter/ops/flydsl/moe_kernels.py b/aiter/ops/flydsl/moe_kernels.py index 0d33bf687c..bd4554e3de 100644 --- a/aiter/ops/flydsl/moe_kernels.py +++ b/aiter/ops/flydsl/moe_kernels.py @@ -7,12 +7,18 @@ import re from typing import Dict, Optional -from aiter.utility import dtypes import torch _KERNEL_PARAMS: Dict[str, Dict] = {} + +def _get_dtypes(): + from aiter.utility import dtypes + + return dtypes + + _SUFFIX_RE = re.compile(r"(?P_fp4)?(?P_fp8)?(?:_sbm(?P\d+))?$") @@ -359,9 +365,12 @@ def _s1_args_fp4( size_expert_ids_in, dev, bias=None, + stream=None, ): empty_f32 = torch.empty(0, device=dev, dtype=torch.float32) _bias = bias if bias is not None else empty_f32 + if stream is None: + stream = torch.cuda.current_stream() return ( _view_safe(out), _view_safe(a), @@ -378,7 +387,7 @@ def _s1_args_fp4( n_in, k_in, size_expert_ids_in, - torch.cuda.current_stream(), + stream, ) @@ -396,7 +405,10 @@ def _s1_args_std( n_in, k_in, size_expert_ids_in, + stream=None, ): + if stream is None: + stream = torch.cuda.current_stream() return ( out, a, @@ -411,7 +423,7 @@ def _s1_args_std( n_in, k_in, size_expert_ids_in, - torch.cuda.current_stream(), + stream, ) @@ -431,12 +443,15 @@ def _s2_args_fp4( blocks, dev, bias=None, + stream=None, ): _bias = ( bias.view(-1) if bias is not None else torch.empty(0, device=dev, dtype=torch.float32) ) + if stream is None: + stream = torch.cuda.current_stream() return ( _view_safe(target), _view_safe(a), @@ -452,7 +467,7 @@ def _s2_args_fp4( n_in, k_in, blocks, - torch.cuda.current_stream(), + stream, ) @@ -470,7 +485,10 @@ def _s2_args_std( n_in, k_in, blocks, + stream=None, ): + if stream is None: + stream = torch.cuda.current_stream() return ( target, a, @@ -485,7 +503,7 @@ def _s2_args_std( n_in, k_in, blocks, - torch.cuda.current_stream(), + stream, ) @@ -589,6 +607,7 @@ def flydsl_moe_stage1( _need_fp8 = out_dtype == "fp8" _fuse_any_quant = _need_fp4 or _need_fp8 _base_out_dtype = "bf16" if _fuse_any_quant else out_dtype + dtypes = _get_dtypes() if _need_fp4: torch_out_dtype = dtypes.fp4x2 diff --git a/aiter/utility/dtypes.py b/aiter/utility/dtypes.py index 5fbe00f9be..25ddecb694 100644 --- a/aiter/utility/dtypes.py +++ b/aiter/utility/dtypes.py @@ -2,7 +2,6 @@ # Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import torch from ..jit.utils.chip_info import get_gfx -from ..jit.core import compile_ops from ..ops.enum import QuantType, ActivationType from .aiter_types import aiter_dtypes, aiter_tensor_t import argparse diff --git a/pyproject.toml b/pyproject.toml index 3066a96157..082b4c37cc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ requires = [ "psutil", "ninja", "pandas", - "flydsl==0.1.4" + "flydsl==0.1.4.2" ] [tool.setuptools_scm] diff --git a/requirements.txt b/requirements.txt index 4e4c7c15cb..8a863d7694 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,4 @@ pyyaml einops pybind11>=3.0.1 ninja -flydsl==0.1.4 +flydsl==0.1.4.2 diff --git a/setup.py b/setup.py index b7c496880a..bf78f9b177 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ OPT_COMPILER_CONFIG = os.path.join(this_dir, "aiter", "jit", "optCompilerConfig.json") PACKAGE_NAME = "amd-aiter" -FLYDSL_VERSION = "flydsl==0.1.4" +FLYDSL_VERSION = "flydsl==0.1.4.2" BUILD_TARGET = os.environ.get("BUILD_TARGET", "auto") PREBUILD_KERNELS = int(os.environ.get("PREBUILD_KERNELS", 0)) @@ -315,18 +315,28 @@ def build_one_module(one_opt_args): prebuid_thread_num = min(prebuid_thread_num, getMaxJobs()) os.environ["PREBUILD_THREAD_NUM"] = str(prebuid_thread_num) + # --- FlyDSL AOT pre-compilation (MOE + GEMM in parallel, before CK) --- + _prev_aot_import = os.environ.get("AITER_AOT_IMPORT") + os.environ["AITER_AOT_IMPORT"] = "1" + try: + from aiter.aot.flydsl.common import start_aot, wait_aot + + flydsl_cache_dir = os.path.join(this_dir, "aiter", "jit", "flydsl_cache") + _flydsl_pool, _flydsl_futures = start_aot(flydsl_cache_dir) + wait_aot(_flydsl_pool, _flydsl_futures) + finally: + if _prev_aot_import is None: + os.environ.pop("AITER_AOT_IMPORT", None) + else: + os.environ["AITER_AOT_IMPORT"] = _prev_aot_import + + # --- CK kernel builds --- with ThreadPoolExecutor(max_workers=prebuid_thread_num) as executor: list(executor.map(build_one_module, all_opts_args_build)) # Retune GEMM shapes on the live GPU after the main build phase. - # Each requested module's tune script benchmarks all CSV shapes and - # writes results tagged with the live GPU's (gfx, cu_num) back to - # the source CSV, then rebuilds the inference .so. if PRETUNE_MODULES: - # Import directly from the file to avoid triggering aiter/__init__.py, - # which would try to load module_aiter_core before it is registered. - sys.path.insert(0, os.path.join(this_dir, "aiter", "utility")) - from pretune import run_pretune_modules # noqa: E402 + from aiter.utility.pretune import run_pretune_modules # noqa: E402 cfg_path = OPT_COMPILER_CONFIG with open(cfg_path, "r", encoding="utf-8") as _f: @@ -340,76 +350,6 @@ def build_one_module(one_opt_args): repo_dir=this_dir, ) - # --- FlyDSL AOT pre-compilation --- - try: - flydsl_cache_dir = os.path.join(this_dir, "aiter", "jit", "flydsl_cache") - os.makedirs(flydsl_cache_dir, exist_ok=True) - os.environ["FLYDSL_RUNTIME_CACHE_DIR"] = flydsl_cache_dir - - # setup.py loads `jit.core` via sys.path (line 134-135). - # Map those modules into the `aiter.*` namespace so that - # `import aiter.jit.core` reuses the same instances. - for _name in list(sys.modules): - if _name == "jit" or _name.startswith("jit."): - _pkg = f"aiter.{_name}" - if _pkg not in sys.modules: - sys.modules[_pkg] = sys.modules[_name] - - from aiter.aot.flydsl.common import collect_aot_jobs - - def _run_flydsl_aot(label, default_csvs, parse_csv, compile_one_config): - jobs = collect_aot_jobs( - default_csvs, - parse_csv, - on_missing_csv=lambda csv_path: print( - f"[aiter] {label}: CSV not found: {csv_path}" - ), - ) - if jobs: - print( - f"[aiter] {label}: {len(jobs)} kernels to compile " - f"(cache: {flydsl_cache_dir})" - ) - results = [] - for job in jobs: - results.append(compile_one_config(**job)) - ok = sum( - 1 for result in results if result["compile_time"] is not None - ) - fail = len(results) - ok - print(f"[aiter] {label}: compiled {ok} ok, {fail} failed") - - from aiter.aot.flydsl.moe import ( - DEFAULT_CSVS as MOE_DEFAULT_CSVS, - compile_one_config as compile_moe_one_config, - parse_csv as parse_moe_csv, - ) - - _run_flydsl_aot( - "FlyDSL MoE AOT", - MOE_DEFAULT_CSVS, - parse_moe_csv, - compile_moe_one_config, - ) - - from aiter.aot.flydsl.gemm import ( - DEFAULT_CSVS as GEMM_DEFAULT_CSVS, - compile_one_config as compile_gemm_one_config, - parse_csv as parse_gemm_csv, - ) - - _run_flydsl_aot( - "FlyDSL GEMM AOT", - GEMM_DEFAULT_CSVS, - parse_gemm_csv, - compile_gemm_one_config, - ) - except Exception as e: - import traceback - - traceback.print_exc() - print(f"[aiter] FlyDSL AOT skipped: {e}") - class NinjaBuildExtension(build_ext): """Custom build_ext that defers expensive operations until run() is called."""