Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions tilelang/autotuner/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

from tilelang import env
from tilelang.autotuner.param import CompileArgs, ProfileArgs, AutotuneResult
from tilelang.utils.language import get_prim_func_name
from tilelang.autotuner.capture import get_autotune_inputs
from tilelang.utils.target import determine_target
from tilelang import __version__
Expand Down Expand Up @@ -332,11 +333,15 @@ def run(self, warmup: int = 25, rep: int = 100, timeout: int = 30):
if env.is_cache_enabled() and not env.is_autotune_cache_disabled():
# First check in-memory cache
if key in self._memory_cache:
# Include PrimFunc name when hitting autotuner memory cache
cached_result = self._memory_cache[key]
prim = getattr(cached_result, "func", None)
kernel_name = get_prim_func_name(prim, "<unknown>")
logger.warning(
"Found kernel in memory cache. For better performance,"
" consider using `@tilelang.autotune` instead of direct AutoTuner.from_kernel."
"Found kernel '%s' in memory cache. For better performance, consider using `@tilelang.autotune` instead of direct AutoTuner.from_kernel.",
kernel_name,
)
return self._memory_cache[key]
return cached_result

# Then check disk cache
result = self._load_result_from_disk(key)
Expand Down
12 changes: 8 additions & 4 deletions tilelang/cache/kernel_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from tvm.tir import PrimFunc
from tvm.runtime import Executable
from tilelang.engine.param import KernelParam
from tilelang.utils.language import get_prim_func_name
from tilelang import env
from tilelang.jit import JITKernel
from tilelang import __version__
Expand Down Expand Up @@ -179,27 +180,30 @@ def cached(
with self._lock:
# First check in-memory cache
if key in self._memory_cache:
# Include kernel name for easier debugging when hitting memory cache
kernel_name = get_prim_func_name(func, "<unknown>")
self.logger.warning(
"Found kernel in memory cache. For better performance, consider using `@tilelang.jit` instead of direct kernel caching."
"Found kernel '%s' in memory cache. For better performance, consider using `@tilelang.jit` instead of direct kernel caching.",
kernel_name,
)
return self._memory_cache[key]

if verbose:
self.logger.debug(f"Checking disk cache for kernel {func.attrs['global_symbol']}")
self.logger.debug(f"Checking disk cache for kernel {get_prim_func_name(func, '<unknown>')}")

# Then check disk cache
kernel = self._load_kernel_from_disk(
key, norm_target, target_host, out_idx, execution_backend, pass_configs, compile_flags, func, verbose
)
if kernel is not None:
if verbose:
self.logger.debug(f"Found kernel in disk cache for {func.attrs['global_symbol']}")
self.logger.debug(f"Found kernel in disk cache for {get_prim_func_name(func, '<unknown>')}")
# Populate memory cache with disk result
self._memory_cache[key] = kernel
return kernel

if verbose:
self.logger.debug(f"No cached kernel for {func.attrs['global_symbol']}")
self.logger.debug(f"No cached kernel for {get_prim_func_name(func, '<unknown>')}")
# Compile kernel if cache miss; leave critical section
kernel = JITKernel(
func,
Expand Down
1 change: 1 addition & 0 deletions tilelang/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@
is_full_region, # noqa: F401
to_buffer_region, # noqa: F401
get_buffer_region_from_load, # noqa: F401
get_prim_func_name, # noqa: F401
)
from .deprecated import deprecated # noqa: F401
24 changes: 24 additions & 0 deletions tilelang/utils/language.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,3 +478,27 @@ def is_full_region(buffer_region: BufferRegion) -> bool:
if not expr_equal(r.extent, dim):
return False
return True


def get_prim_func_name(func: PrimFunc | None, default: str | None = None) -> str | None:
"""
Extract a human‑readable function name from a TVM PrimFunc.

Prefer the `global_symbol` attribute set on the PrimFunc. If it is missing
(e.g., private PrimFunc without a global symbol), return the provided
`default` value.

Args:
func: TVM PrimFunc instance or None.
default: Fallback name to return when no name can be determined.

Returns:
The function name as a string, or `default` when unavailable.
"""
if func is None:
return default
try:
name = func.attrs["global_symbol"]
return str(name) if name is not None else default
except Exception:
return default
Loading