diff --git a/tilelang/autotuner/tuner.py b/tilelang/autotuner/tuner.py index 5bbdc48a4..8d9503739 100644 --- a/tilelang/autotuner/tuner.py +++ b/tilelang/autotuner/tuner.py @@ -37,6 +37,7 @@ from tilelang import env from tilelang.autotuner.param import CompileArgs, ProfileArgs, AutotuneResult +from tilelang.utils.language import get_prim_func_name from tilelang.autotuner.capture import get_autotune_inputs from tilelang.utils.target import determine_target from tilelang import __version__ @@ -332,11 +333,15 @@ def run(self, warmup: int = 25, rep: int = 100, timeout: int = 30): if env.is_cache_enabled() and not env.is_autotune_cache_disabled(): # First check in-memory cache if key in self._memory_cache: + # Include PrimFunc name when hitting autotuner memory cache + cached_result = self._memory_cache[key] + prim = getattr(cached_result, "func", None) + kernel_name = get_prim_func_name(prim, "") logger.warning( - "Found kernel in memory cache. For better performance," - " consider using `@tilelang.autotune` instead of direct AutoTuner.from_kernel." + "Found kernel '%s' in memory cache. For better performance, consider using `@tilelang.autotune` instead of direct AutoTuner.from_kernel.", + kernel_name, ) - return self._memory_cache[key] + return cached_result # Then check disk cache result = self._load_result_from_disk(key) diff --git a/tilelang/cache/kernel_cache.py b/tilelang/cache/kernel_cache.py index 4fbe2dce5..58295406e 100644 --- a/tilelang/cache/kernel_cache.py +++ b/tilelang/cache/kernel_cache.py @@ -16,6 +16,7 @@ from tvm.tir import PrimFunc from tvm.runtime import Executable from tilelang.engine.param import KernelParam +from tilelang.utils.language import get_prim_func_name from tilelang import env from tilelang.jit import JITKernel from tilelang import __version__ @@ -179,13 +180,16 @@ def cached( with self._lock: # First check in-memory cache if key in self._memory_cache: + # Include kernel name for easier debugging when hitting memory cache + kernel_name = get_prim_func_name(func, "") self.logger.warning( - "Found kernel in memory cache. For better performance, consider using `@tilelang.jit` instead of direct kernel caching." + "Found kernel '%s' in memory cache. For better performance, consider using `@tilelang.jit` instead of direct kernel caching.", + kernel_name, ) return self._memory_cache[key] if verbose: - self.logger.debug(f"Checking disk cache for kernel {func.attrs['global_symbol']}") + self.logger.debug(f"Checking disk cache for kernel {get_prim_func_name(func, '')}") # Then check disk cache kernel = self._load_kernel_from_disk( @@ -193,13 +197,13 @@ def cached( ) if kernel is not None: if verbose: - self.logger.debug(f"Found kernel in disk cache for {func.attrs['global_symbol']}") + self.logger.debug(f"Found kernel in disk cache for {get_prim_func_name(func, '')}") # Populate memory cache with disk result self._memory_cache[key] = kernel return kernel if verbose: - self.logger.debug(f"No cached kernel for {func.attrs['global_symbol']}") + self.logger.debug(f"No cached kernel for {get_prim_func_name(func, '')}") # Compile kernel if cache miss; leave critical section kernel = JITKernel( func, diff --git a/tilelang/utils/__init__.py b/tilelang/utils/__init__.py index a713df8e0..df0c71c2e 100644 --- a/tilelang/utils/__init__.py +++ b/tilelang/utils/__init__.py @@ -16,5 +16,6 @@ is_full_region, # noqa: F401 to_buffer_region, # noqa: F401 get_buffer_region_from_load, # noqa: F401 + get_prim_func_name, # noqa: F401 ) from .deprecated import deprecated # noqa: F401 diff --git a/tilelang/utils/language.py b/tilelang/utils/language.py index 584e9998d..ea8e58804 100644 --- a/tilelang/utils/language.py +++ b/tilelang/utils/language.py @@ -478,3 +478,27 @@ def is_full_region(buffer_region: BufferRegion) -> bool: if not expr_equal(r.extent, dim): return False return True + + +def get_prim_func_name(func: PrimFunc | None, default: str | None = None) -> str | None: + """ + Extract a human‑readable function name from a TVM PrimFunc. + + Prefer the `global_symbol` attribute set on the PrimFunc. If it is missing + (e.g., private PrimFunc without a global symbol), return the provided + `default` value. + + Args: + func: TVM PrimFunc instance or None. + default: Fallback name to return when no name can be determined. + + Returns: + The function name as a string, or `default` when unavailable. + """ + if func is None: + return default + try: + name = func.attrs["global_symbol"] + return str(name) if name is not None else default + except Exception: + return default