diff --git a/testing/python/jit/test_tilelang_jit_gemm_cython.py b/testing/python/jit/test_tilelang_jit_gemm_cython.py index 546617583..04c71db9d 100644 --- a/testing/python/jit/test_tilelang_jit_gemm_cython.py +++ b/testing/python/jit/test_tilelang_jit_gemm_cython.py @@ -52,68 +52,6 @@ def main( return main -def run_gemm( - M, - N, - K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - block_M, - block_N, - block_K, - num_stages=3, - num_threads=128, -): - program = matmul( - M, - N, - K, - block_M, - block_N, - block_K, - trans_A, - trans_B, - in_dtype, - out_dtype, - dtypeAccum, - num_stages, - num_threads, - ) - - stramp = "&*(XS)" - - @tvm.register_global_func("tilelang_callback_cuda_postproc", override=True) - def tilelang_callback_cuda_postproc(code, _): - code = f"// {stramp}\n" + code - return code - - matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="cython") - - kernel_source = matmul_kernel.get_kernel_source() - - assert stramp in kernel_source, f"Expected {stramp} in the kernel source" - - -def test_gemm_f16f16f16_nn(): - run_gemm( - 512, - 1024, - 768, - False, - False, - T.float16, - T.float16, - T.float16, - 128, - 256, - 32, - 2, - ) - - def matmu_jit_kernel( M, N, diff --git a/testing/python/language/test_tilelang_capture.py b/testing/python/language/test_tilelang_capture.py deleted file mode 100644 index 47fec999a..000000000 --- a/testing/python/language/test_tilelang_capture.py +++ /dev/null @@ -1,41 +0,0 @@ -import tilelang.language as T -import tilelang.testing -import torch -import weakref -import gc - - -def test_tilelang_capture(): - @tilelang.jit( - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, - ) - def get_dummy_kernel(): - @T.prim_func - def dummy_kernel( - a: T.Tensor[(1,), T.float32], - ): - with T.Kernel(1) as _: - a[0] = 1 - - return dummy_kernel - - a = torch.randn(1, 1024) - a_weak = weakref.ref(a) - _kernel = get_dummy_kernel() - del a - torch.cuda.empty_cache() - gc.collect() - torch.cuda.empty_cache() - a_upgrade = a_weak() - assert a_upgrade is None, "A is not garbage collected" - - # use objgraph to debug - # if a_upgrade is not None: - # objgraph.show_backrefs([a_upgrade], max_depth=5) - - -if __name__ == "__main__": - tilelang.testing.main() diff --git a/testing/python/language/test_tilelang_memory_leak.py b/testing/python/language/test_tilelang_memory_leak.py new file mode 100644 index 000000000..7da187fa3 --- /dev/null +++ b/testing/python/language/test_tilelang_memory_leak.py @@ -0,0 +1,79 @@ +import tvm_ffi +import tilelang +import tilelang.language as T +import tilelang.testing +import torch +import weakref +import gc + + +def test_tilelang_globals_leak(): + @tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, + ) + def get_dummy_kernel(): + @T.prim_func + def dummy_kernel( + a: T.Tensor[(1,), T.float32], + ): + with T.Kernel(1) as _: + a[0] = 1 + + return dummy_kernel + + a = torch.randn(1, 1024) + a_weak = weakref.ref(a) + _kernel = get_dummy_kernel() + del a + torch.cuda.empty_cache() + gc.collect() + torch.cuda.empty_cache() + a_upgrade = a_weak() + assert a_upgrade is None, "A is not garbage collected" + + # use objgraph to debug + # if a_upgrade is not None: + # objgraph.show_backrefs([a_upgrade], max_depth=5) + + +def test_error_no_cyclic_reference() -> None: + # This test case ensures that when an error is raised from C++ side, + # there is no cyclic reference that slows down the garbage collection. + # Please see `_with_append_backtrace` in error.py + + # temporarily disable gc + gc.disable() + + try: + # We should create a class as a probe to detect gc activity + # because weakref doesn't support list, dict or other trivial types + class SampleObject: ... + + # trigger a C++ side KeyError by accessing a non-existent key + def trigger_cpp_side_error() -> None: + try: + tmp_map = tvm_ffi.Map(dict()) + tmp_map["a"] + except KeyError: + pass + + def may_create_cyclic_reference() -> weakref.ReferenceType: + obj = SampleObject() + trigger_cpp_side_error() + return weakref.ref(obj) + + wref = may_create_cyclic_reference() + + # if the object is not collected, wref() will return the object + assert wref() is None, "Cyclic reference occurs inside error handling pipeline" + + finally: + # re-enable gc whenever exception occurs + gc.enable() + + +if __name__ == "__main__": + tilelang.testing.main()