Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions tilelang/autotuner/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,8 @@ def set_kernel_parameters(self, k_parameters: tuple[str, ...], f_parameters: dic
self._kernel_parameters = k_parameters
self._function_parameters = f_parameters

def generate_cache_key(self, parameters: dict[str, Any]) -> AutotuneResult | None:
def generate_cache_key(self, parameters: dict[str, Any],
extra_parameters: dict[str, Any]) -> AutotuneResult | None:
"""Generate a cache key for the auto-tuning process.
"""

Expand All @@ -261,6 +262,7 @@ def _normalize_param(value):
key_data = {
"version": __version__,
"op_parameters": tuple(op_parameters),
"extra_parameters": extra_parameters,
"func_source": func_source,
"configs": self.configs,
"compile_args": hash(self.compile_args),
Expand Down Expand Up @@ -293,10 +295,28 @@ def run(self, warmup: int = 25, rep: int = 100, timeout: int = 30):
sig = inspect.signature(self.fn)
parameters = sig.parameters

# NOTE(chaofan): We need to extract some parameters from the closure.
# Consider the case:
# def gemm(M, N, K):
# def kernel(...)
# If we only extract source, M/N/K will be symbolic and there will be cache problem.
extra_parameters: dict[str, Any] = {}
cells = self.fn.__closure__
var_names = self.fn.__code__.co_freevars
if cells is not None:
assert len(var_names) == len(cells), "Number of free variables does not match"
for var_name, cell in zip(var_names, cells):
if var_name in parameters:
continue
# Cell content must be serializable
assert isinstance(cell.cell_contents, (int, float, str, bool, type(None))), \
f"Cell contents {cell.cell_contents} is not serializable: {type(cell.cell_contents)}"
extra_parameters[var_name] = cell.cell_contents

if isinstance(self.configs, Callable):
self.configs = self.configs(*self._kernel_parameters)

key = self.generate_cache_key(parameters)
key = self.generate_cache_key(parameters, extra_parameters)

with self._lock:
if env.is_cache_enabled():
Expand Down
Loading