Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 0 additions & 62 deletions testing/python/jit/test_tilelang_jit_gemm_cython.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,68 +52,6 @@ def main(
return main


def run_gemm(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
program = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)

stramp = "&*(XS)"

@tvm.register_global_func("tilelang_callback_cuda_postproc", override=True)
def tilelang_callback_cuda_postproc(code, _):
code = f"// {stramp}\n" + code
return code

matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="cython")

kernel_source = matmul_kernel.get_kernel_source()

assert stramp in kernel_source, f"Expected {stramp} in the kernel source"


def test_gemm_f16f16f16_nn():
run_gemm(
512,
1024,
768,
False,
False,
T.float16,
T.float16,
T.float16,
128,
256,
32,
2,
)


def matmu_jit_kernel(
M,
N,
Expand Down
41 changes: 0 additions & 41 deletions testing/python/language/test_tilelang_capture.py

This file was deleted.

79 changes: 79 additions & 0 deletions testing/python/language/test_tilelang_memory_leak.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import tvm_ffi
import tilelang
import tilelang.language as T
import tilelang.testing
import torch
import weakref
import gc


def test_tilelang_globals_leak():
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
},
)
def get_dummy_kernel():
@T.prim_func
def dummy_kernel(
a: T.Tensor[(1,), T.float32],
):
with T.Kernel(1) as _:
a[0] = 1

return dummy_kernel

a = torch.randn(1, 1024)
a_weak = weakref.ref(a)
_kernel = get_dummy_kernel()
del a
torch.cuda.empty_cache()
gc.collect()
torch.cuda.empty_cache()
a_upgrade = a_weak()
assert a_upgrade is None, "A is not garbage collected"

# use objgraph to debug
# if a_upgrade is not None:
# objgraph.show_backrefs([a_upgrade], max_depth=5)


def test_error_no_cyclic_reference() -> None:
# This test case ensures that when an error is raised from C++ side,
# there is no cyclic reference that slows down the garbage collection.
# Please see `_with_append_backtrace` in error.py

# temporarily disable gc
gc.disable()

try:
# We should create a class as a probe to detect gc activity
# because weakref doesn't support list, dict or other trivial types
class SampleObject: ...

# trigger a C++ side KeyError by accessing a non-existent key
def trigger_cpp_side_error() -> None:
try:
tmp_map = tvm_ffi.Map(dict())
tmp_map["a"]
except KeyError:
pass

def may_create_cyclic_reference() -> weakref.ReferenceType:
obj = SampleObject()
trigger_cpp_side_error()
return weakref.ref(obj)

wref = may_create_cyclic_reference()

# if the object is not collected, wref() will return the object
assert wref() is None, "Cyclic reference occurs inside error handling pipeline"

finally:
# re-enable gc whenever exception occurs
gc.enable()


if __name__ == "__main__":
tilelang.testing.main()
Loading