diff --git a/examples/autodd/tilelang_buggy.py b/examples/autodd/tilelang_buggy.py index 47e71fe50..d2c5469bb 100644 --- a/examples/autodd/tilelang_buggy.py +++ b/examples/autodd/tilelang_buggy.py @@ -73,14 +73,10 @@ def get_grid_size(self): return grid_x, grid_y def get_shared_memory_size(self): - return get_memory_requirements( - self.M, self.N, self.K, self.block_M, self.block_N, self.block_K - ) + return get_memory_requirements(self.M, self.N, self.K, self.block_M, self.block_N, self.block_K) def validate(self): - return validate_parameters( - self.M, self.N, self.K, self.block_M, self.block_N, self.block_K - ) + return validate_parameters(self.M, self.N, self.K, self.block_M, self.block_N, self.block_K) def create_reference_output(a, b, activation="relu"): @@ -107,6 +103,7 @@ def benchmark_pytorch(M, N, K, num_iters=10, warmup=5): # Benchmark import time + start = time.time() for _ in range(num_iters): _ = a @ b diff --git a/examples/autodd/tilelang_minimized_expected.py b/examples/autodd/tilelang_minimized_expected.py index 2135f6fce..3dc88f992 100644 --- a/examples/autodd/tilelang_minimized_expected.py +++ b/examples/autodd/tilelang_minimized_expected.py @@ -13,7 +13,6 @@ class MatmulConfig: - def __init__(self, *args, **kwargs): self.M = 1 self.N = 1 @@ -24,7 +23,6 @@ def __init__(self, *args, **kwargs): def buggy_matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32, *args, **kwargs): - @T.prim_func def matmul_kernel(): with T.Kernel(): @@ -45,7 +43,7 @@ def main(*args, **kwargs): try: run_kernel(config) except Exception as e: - print(f'{e}') + print(f"{e}") main() diff --git a/tilelang/language/eager/builder.py b/tilelang/language/eager/builder.py index fbd950262..6b00a3426 100644 --- a/tilelang/language/eager/builder.py +++ b/tilelang/language/eager/builder.py @@ -1059,19 +1059,19 @@ def parse_args(self, *args, **kwargs): tir_temp = self.p1_cache.get(p1_key, None) if tir_temp is None: # mode should be set by JITImpl before calling parse_args - tir_temp = self._build_tir_template(*args, **kwargs) + tir_temp = self._build_tir_template(**kwargs) self.p1_cache[p1_key] = tir_temp p2_key = tir_temp._parse_phase2_key(**tensor_args) return (p1_key, p2_key), tensor_args def get_tir(self, *args, **kwargs): - (p1_key, _), tensor_args = self.parse_args(*args, **kwargs) + p1_key, tensor_args, kwargs = self._parse_phase1_key(*args, **kwargs) if p1_key not in self.p1_cache: # in legacy gemm, we use lazy tir template to build the tir - tir_temp = self._build_tir_template(*args, **kwargs) + tir_temp = self._build_tir_template(**kwargs) self.p1_cache[p1_key] = tir_temp - return tir_temp.get_tir(**tensor_args) - return self.p1_cache[p1_key].get_tir(**tensor_args) + return tir_temp.get_tir(**tensor_args, **kwargs) + return self.p1_cache[p1_key].get_tir(**tensor_args, **kwargs) def __call__(self, *args, **kwargs): return self.get_tir(*args, **kwargs)