diff --git a/examples/flash_decoding/example_gqa_decode_varlen_logits.py b/examples/flash_decoding/example_gqa_decode_varlen_logits.py index 8f26a59c3..30acd879e 100644 --- a/examples/flash_decoding/example_gqa_decode_varlen_logits.py +++ b/examples/flash_decoding/example_gqa_decode_varlen_logits.py @@ -7,7 +7,6 @@ import tilelang.language as T torch.manual_seed(0) -tilelang.disable_cache() def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: @@ -197,7 +196,6 @@ def get_configs(): return configs -@autotune(configs=get_configs(), warmup=10, rep=10) @tilelang.jit(out_idx=[-2, -1]) def flashattn( batch, heads, k_heads, max_seqlen_kv, total_seqlen_k, dim, has_sink, block_N=128, block_H=64, num_split=1, num_stages=1, threads=128 diff --git a/tilelang/autotuner/tuner.py b/tilelang/autotuner/tuner.py index 95595d5dc..e54286b41 100644 --- a/tilelang/autotuner/tuner.py +++ b/tilelang/autotuner/tuner.py @@ -9,6 +9,7 @@ import tilelang from tilelang import tvm as tvm +from tilelang import env from tilelang.jit import JITImpl from tilelang.jit.kernel import JITKernel from tvm.tir import PrimFunc, Var @@ -35,7 +36,6 @@ import traceback from pathlib import Path -from tilelang import env from tilelang.autotuner.param import CompileArgs, ProfileArgs, AutotuneResult from tilelang.utils.language import get_prim_func_name from tilelang.autotuner.capture import get_autotune_inputs @@ -143,25 +143,40 @@ def from_kernel(cls, kernel: Callable, configs): def set_compile_args( self, out_idx: list[int] | int | None = None, - target: Literal["auto", "cuda", "hip", "metal"] = "auto", - execution_backend: Literal["auto", "tvm_ffi", "cython", "nvrtc", "torch"] = "auto", - target_host: str | Target = None, - verbose: bool = False, + target: Literal["auto", "cuda", "hip", "metal"] | None = None, + execution_backend: Literal["auto", "tvm_ffi", "cython", "nvrtc", "torch"] | None = None, + target_host: str | Target | None = None, + verbose: bool | None = None, pass_configs: dict[str, Any] | None = None, ): """Set compilation arguments for the auto-tuner. Args: out_idx: List of output tensor indices. - target: Target platform. - execution_backend: Execution backend to use for kernel execution. + target: Target platform. If None, reads from TILELANG_TARGET environment variable (defaults to "auto"). + execution_backend: Execution backend to use for kernel execution. If None, reads from + TILELANG_EXECUTION_BACKEND environment variable (defaults to "auto"). target_host: Target host for cross-compilation. - verbose: Whether to enable verbose output. + verbose: Whether to enable verbose output. If None, reads from + TILELANG_VERBOSE environment variable (defaults to False). pass_configs: Additional keyword arguments to pass to the Compiler PassContext. + Environment Variables: + TILELANG_TARGET: Default compilation target (e.g., "cuda", "llvm"). Defaults to "auto". + TILELANG_EXECUTION_BACKEND: Default execution backend. Defaults to "auto". + TILELANG_VERBOSE: Set to "1", "true", "yes", or "on" to enable verbose compilation by default. + Returns: AutoTuner: Self for method chaining. """ + # Apply environment variable defaults if parameters are not explicitly set + if target is None: + target = env.get_default_target() + if execution_backend is None: + execution_backend = env.get_default_execution_backend() + if verbose is None: + verbose = env.get_default_verbose() + # Normalize target to a concrete TVM Target and resolve execution backend t = Target(determine_target(target)) from tilelang.jit.execution_backend import resolve_execution_backend diff --git a/tilelang/cache/__init__.py b/tilelang/cache/__init__.py index 9774e8a9c..b0575e6c0 100644 --- a/tilelang/cache/__init__.py +++ b/tilelang/cache/__init__.py @@ -17,10 +17,10 @@ def cached( func: PrimFunc = None, out_idx: list[int] = None, *args, - target: str | Target = "auto", - target_host: str | Target = None, - execution_backend: Literal["auto", "tvm_ffi", "cython", "nvrtc", "torch"] | None = "auto", - verbose: bool | None = False, + target: str | Target | None = None, + target_host: str | Target | None = None, + execution_backend: Literal["auto", "tvm_ffi", "cython", "nvrtc", "torch"] | None = None, + verbose: bool | None = None, pass_configs: dict | None = None, compile_flags: list[str] | str | None = None, ) -> JITKernel: diff --git a/tilelang/cache/kernel_cache.py b/tilelang/cache/kernel_cache.py index ded861a3d..30e94aea1 100644 --- a/tilelang/cache/kernel_cache.py +++ b/tilelang/cache/kernel_cache.py @@ -121,26 +121,49 @@ def cached( func: PrimFunc = None, out_idx: list[int] = None, *args, - target: str | Target = "auto", - target_host: str | Target = None, - execution_backend: Literal["auto", "tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"] = "auto", - verbose: bool = False, + target: str | Target | None = None, + target_host: str | Target | None = None, + execution_backend: Literal["auto", "tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"] | None = None, + verbose: bool | None = None, pass_configs: dict = None, compile_flags: list[str] | str | None = None, ) -> JITKernel: """ Caches and reuses compiled kernels to avoid redundant compilation. + This is the ONLY place where environment variable processing, target normalization, + and execution backend resolution should happen. All compilation paths go through here. + Args: func: Function to be compiled or a prepared PrimFunc out_idx: Indices specifying which outputs to return - target: Compilation target platform + target: Compilation target platform (None = read from TILELANG_TARGET env var) target_host: Host target platform + execution_backend: Execution backend (None = read from TILELANG_EXECUTION_BACKEND) + verbose: Enable verbose output (None = read from TILELANG_VERBOSE) *args: Arguments passed to func Returns: JITKernel: The compiled kernel, either freshly compiled or from cache + + Environment Variables + --------------------- + TILELANG_TARGET : str + Default compilation target (e.g., "cuda", "llvm"). Defaults to "auto". + TILELANG_EXECUTION_BACKEND : str + Default execution backend. Defaults to "auto". + TILELANG_VERBOSE : str + Set to "1", "true", "yes", or "on" to enable verbose compilation by default. """ + # Apply environment variable defaults if parameters are not explicitly set + # This is the SINGLE source of truth for env var processing + if target is None: + target = env.get_default_target() + if execution_backend is None: + execution_backend = env.get_default_execution_backend() + if verbose is None: + verbose = env.get_default_verbose() + # Normalize target and resolve execution backend before proceeding from tilelang.utils.target import determine_target as _determine_target from tilelang.jit.execution_backend import resolve_execution_backend, allowed_backends_for_target diff --git a/tilelang/env.py b/tilelang/env.py index 0583cd4cf..47ed4c3dc 100644 --- a/tilelang/env.py +++ b/tilelang/env.py @@ -245,6 +245,12 @@ class Environment: TILELANG_AUTO_TUNING_CPU_COUNTS = EnvVar("TILELANG_AUTO_TUNING_CPU_COUNTS", "-1") # -1 means auto TILELANG_AUTO_TUNING_MAX_CPU_COUNT = EnvVar("TILELANG_AUTO_TUNING_MAX_CPU_COUNT", "-1") # -1 means no limit + # Compilation defaults (for jit, autotune, compile) + # These allow overriding default compilation parameters via environment variables + TILELANG_DEFAULT_TARGET = EnvVar("TILELANG_TARGET", "auto") + TILELANG_DEFAULT_EXECUTION_BACKEND = EnvVar("TILELANG_EXECUTION_BACKEND", "auto") + TILELANG_DEFAULT_VERBOSE = EnvVar("TILELANG_VERBOSE", "0") + # TVM integration SKIP_LOADING_TILELANG_SO = EnvVar("SKIP_LOADING_TILELANG_SO", "0") TVM_IMPORT_PYTHON_PATH = EnvVar("TVM_IMPORT_PYTHON_PATH", None) @@ -289,6 +295,18 @@ def use_gemm_v1(self) -> bool: """ return str(self.TILELANG_USE_GEMM_V1).lower() in ("1", "true", "yes", "on") + def get_default_target(self) -> str: + """Get default compilation target from environment.""" + return self.TILELANG_DEFAULT_TARGET + + def get_default_execution_backend(self) -> str: + """Get default execution backend from environment.""" + return self.TILELANG_DEFAULT_EXECUTION_BACKEND + + def get_default_verbose(self) -> bool: + """Get default verbose flag from environment.""" + return self.TILELANG_DEFAULT_VERBOSE.lower() in ("1", "true", "yes", "on") + # Instantiate as a global configuration object env = Environment() diff --git a/tilelang/jit/__init__.py b/tilelang/jit/__init__.py index 23163b1fe..502ad4b5b 100644 --- a/tilelang/jit/__init__.py +++ b/tilelang/jit/__init__.py @@ -29,7 +29,6 @@ from tvm.target import Target from tilelang.jit.kernel import JITKernel -from tilelang.utils.target import determine_target from tilelang.cache import cached from os import path, makedirs from logging import getLogger @@ -49,15 +48,16 @@ def compile( func: PrimFunc[_KP, _T] = None, out_idx: list[int] | int | None = None, - execution_backend: Literal["auto", "dlpack", "tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"] = "auto", - target: str | Target = "auto", + execution_backend: Literal["auto", "dlpack", "tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"] | None = None, + target: str | Target | None = None, target_host: str | Target | None = None, - verbose: bool = False, + verbose: bool | None = None, pass_configs: dict[str, Any] | None = None, compile_flags: list[str] | str | None = None, ) -> JITKernel[_KP, _T]: """ Compile the given TileLang PrimFunc with TVM and build a JITKernel. + Parameters ---------- func : tvm.tir.PrimFunc, optional @@ -65,17 +65,28 @@ def compile( out_idx : Union[List[int], int], optional Index(es) of the output tensors to return (default: None). execution_backend : Literal["auto", "dlpack", "tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"], optional - Execution backend to use for kernel execution. Use "auto" to pick a sensible - default per target (cuda->tvm_ffi, metal->torch, others->cython). + Execution backend to use for kernel execution. If None, reads from + TILELANG_EXECUTION_BACKEND environment variable (defaults to "auto"). target : Union[str, Target], optional - Compilation target, either as a string or a TVM Target object (default: "auto"). + Compilation target, either as a string or a TVM Target object. If None, reads from + TILELANG_TARGET environment variable (defaults to "auto"). target_host : Union[str, Target], optional Target host for cross-compilation (default: None). verbose : bool, optional - Whether to enable verbose output (default: False). + Whether to enable verbose output. If None, reads from + TILELANG_VERBOSE environment variable (defaults to False). pass_configs : dict, optional Additional keyword arguments to pass to the Compiler PassContext. Refer to `tilelang.transform.PassConfigKey` for supported options. + + Environment Variables + --------------------- + TILELANG_TARGET : str + Default compilation target (e.g., "cuda", "llvm"). Defaults to "auto". + TILELANG_EXECUTION_BACKEND : str + Default execution backend. Defaults to "auto". + TILELANG_VERBOSE : str + Set to "1", "true", "yes", or "on" to enable verbose compilation by default. """ assert isinstance(func, PrimFunc), f"target function must be a PrimFunc but got {type(func)}" @@ -85,24 +96,6 @@ def compile( raise ValueError("Out index conflict: out_idx is specified and prim_func have returned `T.empty` tensors") out_idx = func.out_idx_override or out_idx - # This path is not a performance critical path, so we can afford to convert the target. - target = Target(determine_target(target)) - - # Resolve execution backend (handles aliases, auto, validation per target) - requested_backend = execution_backend - from tilelang.jit.execution_backend import resolve_execution_backend, allowed_backends_for_target - - execution_backend = resolve_execution_backend(requested_backend, target) - if verbose: - allowed_now = allowed_backends_for_target(target, include_unavailable=False) - logger.info( - "Execution backend resolved -> '%s' (requested='%s', target='%s', allowed: %s)", - execution_backend, - requested_backend, - target.kind.name, - ", ".join(sorted(allowed_now)), - ) - return cached( func=func, out_idx=out_idx, @@ -118,17 +111,18 @@ def compile( def par_compile( funcs: Iterable[PrimFunc[_KP, _T]], out_idx: list[int] | int | None = None, - execution_backend: Literal["auto", "dlpack", "tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"] = "auto", - target: str | Target = "auto", + execution_backend: Literal["auto", "dlpack", "tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"] | None = None, + target: str | Target | None = None, target_host: str | Target | None = None, - verbose: bool = False, + verbose: bool | None = None, pass_configs: dict[str, Any] | None = None, compile_flags: list[str] | str | None = None, - num_workers: int = None, + num_workers: int | None = None, ignore_error: bool = False, ) -> list[JITKernel[_KP, _T]]: """ Parallel compile multiple TileLang PrimFunc with TVM and build JITKernels. + Parameters ---------- funcs : Iterable[tvm.tir.PrimFunc] @@ -136,18 +130,30 @@ def par_compile( out_idx : Union[List[int], int], optional Index(es) of the output tensors to return (default: None). execution_backend : Literal["auto", "dlpack", "tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"], optional - Execution backend to use for kernel execution. Use "auto" to pick a sensible - default per target (cuda->tvm_ffi, metal->torch, others->cython). + Execution backend to use for kernel execution. If None, reads from + TILELANG_EXECUTION_BACKEND environment variable (defaults to "auto"). target : Union[str, Target], optional - Compilation target, either as a string or a TVM Target object (default: "auto"). + Compilation target, either as a string or a TVM Target object. If None, reads from + TILELANG_TARGET environment variable (defaults to "auto"). target_host : Union[str, Target], optional Target host for cross-compilation (default: None). verbose : bool, optional - Whether to enable verbose output (default: False). + Whether to enable verbose output. If None, reads from + TILELANG_VERBOSE environment variable (defaults to False). pass_configs : dict, optional Additional keyword arguments to pass to the Compiler PassContext. Refer to `tilelang.transform.PassConfigKey` for supported options. + + Environment Variables + --------------------- + TILELANG_TARGET : str + Default compilation target (e.g., "cuda", "llvm"). Defaults to "auto". + TILELANG_EXECUTION_BACKEND : str + Default execution backend. Defaults to "auto". + TILELANG_VERBOSE : str + Set to "1", "true", "yes", or "on" to enable verbose compilation by default. """ + with concurrent.futures.ThreadPoolExecutor(num_workers, "tl-par-comp") as executor: futures = [] future_map = {} @@ -256,10 +262,10 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]): """ out_idx: list[int] | int | None - execution_backend: Literal["auto", "dlpack", "tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"] - target: str | Target - target_host: str | Target - verbose: bool + execution_backend: Literal["auto", "dlpack", "tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"] | None + target: str | Target | None + target_host: str | Target | None + verbose: bool | None pass_configs: dict[str, Any] | None debug_root_path: str | None compile_flags: list[str] | str | None @@ -412,7 +418,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _Ret: tune_params = kwargs.pop("__tune_params", {}) - kernel = self._kernel_cache.get(key, None) + kernel = self._kernel_cache.get(key) if kernel is None: kernel = self.compile(*args, **kwargs, **tune_params) self._kernel_cache[key] = kernel @@ -435,10 +441,10 @@ def jit(func: Callable[_P, PrimFunc[_KP, _T]]) -> JITImpl[_P, _KP, _T, JITKernel def jit( *, # Indicates subsequent arguments are keyword-only out_idx: Any = None, - target: str | Target = "auto", - target_host: str | Target = None, - execution_backend: ExecutionBackend = "auto", - verbose: bool = False, + target: str | Target | None = None, + target_host: str | Target | None = None, + execution_backend: ExecutionBackend | None = None, + verbose: bool | None = None, pass_configs: dict[str, Any] | None = None, debug_root_path: str | None = None, compile_flags: list[str] | str | None = None, @@ -449,10 +455,10 @@ def jit( # This is the new public interface func: Callable[_P, _T] | PrimFunc | None = None, *, # Indicates subsequent arguments are keyword-only out_idx: Any = None, - target: str | Target = "auto", - target_host: str | Target = None, - execution_backend: ExecutionBackend = "auto", - verbose: bool = False, + target: str | Target | None = None, + target_host: str | Target | None = None, + execution_backend: ExecutionBackend | None = None, + verbose: bool | None = None, pass_configs: dict[str, Any] | None = None, debug_root_path: str | None = None, compile_flags: list[str] | str | None = None, @@ -470,19 +476,30 @@ def jit( # This is the new public interface If using `@tilelang.jit` directly on a function, this argument is implicitly the function to be decorated (and `out_idx` will be `None`). target : Union[str, Target], optional - Compilation target for TVM (e.g., "cuda", "llvm"). Defaults to "auto". + Compilation target for TVM (e.g., "cuda", "llvm"). If None, reads from + TILELANG_TARGET environment variable (defaults to "auto"). target_host : Union[str, Target], optional Target host for cross-compilation. Defaults to None. execution_backend : Literal["auto", "dlpack", "tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"], optional - Backend for kernel execution and argument passing. Use "auto" to pick a sensible - default per target (cuda->tvm_ffi, metal->torch, others->cython). + Backend for kernel execution and argument passing. If None, reads from + TILELANG_EXECUTION_BACKEND environment variable (defaults to "auto"). verbose : bool, optional - Enables verbose logging during compilation. Defaults to False. + Enables verbose logging during compilation. If None, reads from + TILELANG_VERBOSE environment variable (defaults to False). pass_configs : Optional[Dict[str, Any]], optional Configurations for TVM's pass context. Defaults to None. debug_root_path : Optional[str], optional Directory to save compiled kernel source for debugging. Defaults to None. + Environment Variables + --------------------- + TILELANG_TARGET : str + Default compilation target (e.g., "cuda", "llvm"). Defaults to "auto". + TILELANG_EXECUTION_BACKEND : str + Default execution backend. Defaults to "auto". + TILELANG_VERBOSE : str + Set to "1", "true", "yes", or "on" to enable verbose compilation by default. + Returns ------- Callable @@ -524,10 +541,10 @@ def lazy_jit(func: Callable[_KP, _T]) -> JITImpl[_KP, _KP, _T, _T]: ... def lazy_jit( *, out_idx: Any = None, - target: str | Target = "auto", - target_host: str | Target = None, - execution_backend: ExecutionBackend = "auto", - verbose: bool = False, + target: str | Target | None = None, + target_host: str | Target | None = None, + execution_backend: ExecutionBackend | None = None, + verbose: bool | None = None, pass_configs: dict[str, Any] | None = None, debug_root_path: str | None = None, compile_flags: list[str] | str | None = None, @@ -537,14 +554,21 @@ def lazy_jit( def lazy_jit( func: Callable[_P, _T] | PrimFunc | None = None, *, # Indicates subsequent arguments are keyword-only - target: str | Target = "auto", - target_host: str | Target = None, - execution_backend: ExecutionBackend = "auto", - verbose: bool = False, + target: str | Target | None = None, + target_host: str | Target | None = None, + execution_backend: ExecutionBackend | None = None, + verbose: bool | None = None, pass_configs: dict[str, Any] | None = None, debug_root_path: str | None = None, compile_flags: list[str] | str | None = None, ): + """ + Lazy JIT compiler decorator - returns the kernel object on first call, then executes it. + + Supports environment variable defaults for target, execution_backend, and verbose. + See `jit` documentation for parameter details and environment variables. + """ + compile_args = dict( out_idx=None, execution_backend=execution_backend,