Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions examples/autodd/tilelang_buggy.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,10 @@ def get_grid_size(self):
return grid_x, grid_y

def get_shared_memory_size(self):
return get_memory_requirements(
self.M, self.N, self.K, self.block_M, self.block_N, self.block_K
)
return get_memory_requirements(self.M, self.N, self.K, self.block_M, self.block_N, self.block_K)

def validate(self):
return validate_parameters(
self.M, self.N, self.K, self.block_M, self.block_N, self.block_K
)
return validate_parameters(self.M, self.N, self.K, self.block_M, self.block_N, self.block_K)


def create_reference_output(a, b, activation="relu"):
Expand All @@ -107,6 +103,7 @@ def benchmark_pytorch(M, N, K, num_iters=10, warmup=5):

# Benchmark
import time

start = time.time()
for _ in range(num_iters):
_ = a @ b
Expand Down
4 changes: 1 addition & 3 deletions examples/autodd/tilelang_minimized_expected.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@


class MatmulConfig:

def __init__(self, *args, **kwargs):
self.M = 1
self.N = 1
Expand All @@ -24,7 +23,6 @@ def __init__(self, *args, **kwargs):


def buggy_matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32, *args, **kwargs):

@T.prim_func
def matmul_kernel():
with T.Kernel():
Expand All @@ -45,7 +43,7 @@ def main(*args, **kwargs):
try:
run_kernel(config)
except Exception as e:
print(f'{e}')
print(f"{e}")


main()
10 changes: 5 additions & 5 deletions tilelang/language/eager/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1059,19 +1059,19 @@ def parse_args(self, *args, **kwargs):
tir_temp = self.p1_cache.get(p1_key, None)
if tir_temp is None:
# mode should be set by JITImpl before calling parse_args
tir_temp = self._build_tir_template(*args, **kwargs)
tir_temp = self._build_tir_template(**kwargs)
self.p1_cache[p1_key] = tir_temp
p2_key = tir_temp._parse_phase2_key(**tensor_args)
return (p1_key, p2_key), tensor_args

def get_tir(self, *args, **kwargs):
(p1_key, _), tensor_args = self.parse_args(*args, **kwargs)
p1_key, tensor_args, kwargs = self._parse_phase1_key(*args, **kwargs)
if p1_key not in self.p1_cache:
# in legacy gemm, we use lazy tir template to build the tir
tir_temp = self._build_tir_template(*args, **kwargs)
tir_temp = self._build_tir_template(**kwargs)
self.p1_cache[p1_key] = tir_temp
return tir_temp.get_tir(**tensor_args)
return self.p1_cache[p1_key].get_tir(**tensor_args)
return tir_temp.get_tir(**tensor_args, **kwargs)
return self.p1_cache[p1_key].get_tir(**tensor_args, **kwargs)

def __call__(self, *args, **kwargs):
return self.get_tir(*args, **kwargs)
Expand Down
Loading