Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions examples/flash_decoding/example_gqa_decode_varlen_logits.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import tilelang.language as T

torch.manual_seed(0)
tilelang.disable_cache()


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
Expand Down Expand Up @@ -197,7 +196,6 @@ def get_configs():
return configs


@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[-2, -1])
def flashattn(
batch, heads, k_heads, max_seqlen_kv, total_seqlen_k, dim, has_sink, block_N=128, block_H=64, num_split=1, num_stages=1, threads=128
Expand Down
31 changes: 23 additions & 8 deletions tilelang/autotuner/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import tilelang
from tilelang import tvm as tvm
from tilelang import env
from tilelang.jit import JITImpl
from tilelang.jit.kernel import JITKernel
from tvm.tir import PrimFunc, Var
Expand All @@ -35,7 +36,6 @@
import traceback
from pathlib import Path

from tilelang import env
from tilelang.autotuner.param import CompileArgs, ProfileArgs, AutotuneResult
from tilelang.utils.language import get_prim_func_name
from tilelang.autotuner.capture import get_autotune_inputs
Expand Down Expand Up @@ -143,25 +143,40 @@ def from_kernel(cls, kernel: Callable, configs):
def set_compile_args(
self,
out_idx: list[int] | int | None = None,
target: Literal["auto", "cuda", "hip", "metal"] = "auto",
execution_backend: Literal["auto", "tvm_ffi", "cython", "nvrtc", "torch"] = "auto",
target_host: str | Target = None,
verbose: bool = False,
target: Literal["auto", "cuda", "hip", "metal"] | None = None,
execution_backend: Literal["auto", "tvm_ffi", "cython", "nvrtc", "torch"] | None = None,
target_host: str | Target | None = None,
verbose: bool | None = None,
pass_configs: dict[str, Any] | None = None,
):
"""Set compilation arguments for the auto-tuner.

Args:
out_idx: List of output tensor indices.
target: Target platform.
execution_backend: Execution backend to use for kernel execution.
target: Target platform. If None, reads from TILELANG_TARGET environment variable (defaults to "auto").
execution_backend: Execution backend to use for kernel execution. If None, reads from
TILELANG_EXECUTION_BACKEND environment variable (defaults to "auto").
target_host: Target host for cross-compilation.
verbose: Whether to enable verbose output.
verbose: Whether to enable verbose output. If None, reads from
TILELANG_VERBOSE environment variable (defaults to False).
pass_configs: Additional keyword arguments to pass to the Compiler PassContext.

Environment Variables:
TILELANG_TARGET: Default compilation target (e.g., "cuda", "llvm"). Defaults to "auto".
TILELANG_EXECUTION_BACKEND: Default execution backend. Defaults to "auto".
TILELANG_VERBOSE: Set to "1", "true", "yes", or "on" to enable verbose compilation by default.

Returns:
AutoTuner: Self for method chaining.
"""
# Apply environment variable defaults if parameters are not explicitly set
if target is None:
target = env.get_default_target()
if execution_backend is None:
execution_backend = env.get_default_execution_backend()
if verbose is None:
verbose = env.get_default_verbose()

# Normalize target to a concrete TVM Target and resolve execution backend
t = Target(determine_target(target))
from tilelang.jit.execution_backend import resolve_execution_backend
Expand Down
8 changes: 4 additions & 4 deletions tilelang/cache/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ def cached(
func: PrimFunc = None,
out_idx: list[int] = None,
*args,
target: str | Target = "auto",
target_host: str | Target = None,
execution_backend: Literal["auto", "tvm_ffi", "cython", "nvrtc", "torch"] | None = "auto",
verbose: bool | None = False,
target: str | Target | None = None,
target_host: str | Target | None = None,
execution_backend: Literal["auto", "tvm_ffi", "cython", "nvrtc", "torch"] | None = None,
verbose: bool | None = None,
pass_configs: dict | None = None,
compile_flags: list[str] | str | None = None,
) -> JITKernel:
Expand Down
33 changes: 28 additions & 5 deletions tilelang/cache/kernel_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,26 +121,49 @@ def cached(
func: PrimFunc = None,
out_idx: list[int] = None,
*args,
target: str | Target = "auto",
target_host: str | Target = None,
execution_backend: Literal["auto", "tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"] = "auto",
verbose: bool = False,
target: str | Target | None = None,
target_host: str | Target | None = None,
execution_backend: Literal["auto", "tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"] | None = None,
verbose: bool | None = None,
pass_configs: dict = None,
compile_flags: list[str] | str | None = None,
) -> JITKernel:
"""
Caches and reuses compiled kernels to avoid redundant compilation.

This is the ONLY place where environment variable processing, target normalization,
and execution backend resolution should happen. All compilation paths go through here.

Args:
func: Function to be compiled or a prepared PrimFunc
out_idx: Indices specifying which outputs to return
target: Compilation target platform
target: Compilation target platform (None = read from TILELANG_TARGET env var)
target_host: Host target platform
execution_backend: Execution backend (None = read from TILELANG_EXECUTION_BACKEND)
verbose: Enable verbose output (None = read from TILELANG_VERBOSE)
*args: Arguments passed to func

Returns:
JITKernel: The compiled kernel, either freshly compiled or from cache

Environment Variables
---------------------
TILELANG_TARGET : str
Default compilation target (e.g., "cuda", "llvm"). Defaults to "auto".
TILELANG_EXECUTION_BACKEND : str
Default execution backend. Defaults to "auto".
TILELANG_VERBOSE : str
Set to "1", "true", "yes", or "on" to enable verbose compilation by default.
"""
# Apply environment variable defaults if parameters are not explicitly set
# This is the SINGLE source of truth for env var processing
if target is None:
target = env.get_default_target()
if execution_backend is None:
execution_backend = env.get_default_execution_backend()
if verbose is None:
verbose = env.get_default_verbose()

# Normalize target and resolve execution backend before proceeding
from tilelang.utils.target import determine_target as _determine_target
from tilelang.jit.execution_backend import resolve_execution_backend, allowed_backends_for_target
Expand Down
18 changes: 18 additions & 0 deletions tilelang/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,12 @@ class Environment:
TILELANG_AUTO_TUNING_CPU_COUNTS = EnvVar("TILELANG_AUTO_TUNING_CPU_COUNTS", "-1") # -1 means auto
TILELANG_AUTO_TUNING_MAX_CPU_COUNT = EnvVar("TILELANG_AUTO_TUNING_MAX_CPU_COUNT", "-1") # -1 means no limit

# Compilation defaults (for jit, autotune, compile)
# These allow overriding default compilation parameters via environment variables
TILELANG_DEFAULT_TARGET = EnvVar("TILELANG_TARGET", "auto")
TILELANG_DEFAULT_EXECUTION_BACKEND = EnvVar("TILELANG_EXECUTION_BACKEND", "auto")
TILELANG_DEFAULT_VERBOSE = EnvVar("TILELANG_VERBOSE", "0")

# TVM integration
SKIP_LOADING_TILELANG_SO = EnvVar("SKIP_LOADING_TILELANG_SO", "0")
TVM_IMPORT_PYTHON_PATH = EnvVar("TVM_IMPORT_PYTHON_PATH", None)
Expand Down Expand Up @@ -289,6 +295,18 @@ def use_gemm_v1(self) -> bool:
"""
return str(self.TILELANG_USE_GEMM_V1).lower() in ("1", "true", "yes", "on")

def get_default_target(self) -> str:
"""Get default compilation target from environment."""
return self.TILELANG_DEFAULT_TARGET

def get_default_execution_backend(self) -> str:
"""Get default execution backend from environment."""
return self.TILELANG_DEFAULT_EXECUTION_BACKEND

def get_default_verbose(self) -> bool:
"""Get default verbose flag from environment."""
return self.TILELANG_DEFAULT_VERBOSE.lower() in ("1", "true", "yes", "on")


# Instantiate as a global configuration object
env = Environment()
Expand Down
Loading
Loading