diff --git a/python/tvm_ffi/error.py b/python/tvm_ffi/error.py index 015fe805..4b4a8c26 100644 --- a/python/tvm_ffi/error.py +++ b/python/tvm_ffi/error.py @@ -121,8 +121,18 @@ def append_traceback( The new traceback with the appended frame. """ - frame = self._create_frame(filename, lineno, func) - return types.TracebackType(tb, frame, frame.f_lasti, lineno) + + # This approach avoids binding the created frame object to a local variable + # in `append_traceback`, which would create a reference cycle. By using a + # nested function, the frame object is a temporary that is not held by + # the locals of `append_traceback`. See the diagram in `_with_append_backtrace` + # and PR #327 for more details. + def create( + tb: types.TracebackType | None, frame: types.FrameType, lineno: int + ) -> types.TracebackType: + return types.TracebackType(tb, frame, frame.f_lasti, lineno) + + return create(tb, self._create_frame(filename, lineno, func), lineno) _TRACEBACK_MANAGER = TracebackManager() @@ -130,10 +140,47 @@ def append_traceback( def _with_append_backtrace(py_error: BaseException, backtrace: str) -> BaseException: """Append the backtrace to the py_error and return it.""" + # We manually delete py_error and tb to avoid reference cycle, making it faster to gc the locals inside the frame + # please see pull request #327 for more details + # + # Memory Cycle Diagram: + # + # [Stack Frames] [Heap Objects] + # +-------------------+ + # | outside functions | -----------------------> [ Tensor ] + # +-------------------+ (Held by cycle, slow to free) + # ^ + # | f_back + # +-------------------+ locals py_error + # | py_error (this) | -----+--------------> [ BaseException ] + # +-------------------+ | | + # ^ | | (with_traceback) + # | f_back | v + # +-------------------+ +--------------> [ Traceback Obj ] + # | append_traceback | tb | + # +-------------------+ | + # ^ | + # | f_back | + # +-------------------+ | + # | _create_frame | | + # +-------------------+ | + # ^ | + # | f_back | + # +-------------------+ | + # | _get_frame | <----------------------------+ + # +-------------------+ (Cycle closes here) tb = py_error.__traceback__ - for filename, lineno, func in _parse_backtrace(backtrace): - tb = _TRACEBACK_MANAGER.append_traceback(tb, filename, lineno, func) - return py_error.with_traceback(tb) + try: + for filename, lineno, func in _parse_backtrace(backtrace): + tb = _TRACEBACK_MANAGER.append_traceback(tb, filename, lineno, func) + return py_error.with_traceback(tb) + finally: + # We explicitly break the reference cycle here. The `finally` block is + # executed just before the function returns, after the `return` expression + # in the `try` block has been evaluated. Deleting `py_error` and `tb` + # here ensures they are not held by this function's frame's locals, + # which resolves the cycle. + del py_error, tb def _traceback_to_backtrace_str(tb: types.TracebackType | None) -> str: diff --git a/tests/python/test_error.py b/tests/python/test_error.py index dd94cf39..068fb3fd 100644 --- a/tests/python/test_error.py +++ b/tests/python/test_error.py @@ -16,6 +16,8 @@ # under the License. +import gc +import weakref from typing import NoReturn import pytest @@ -113,3 +115,39 @@ def raise_cxx_error() -> None: ffi_error2 = fecho(e) assert ffi_error1.backtrace.find("raise_cxx_error") != -1 assert ffi_error2.backtrace.find("raise_cxx_error") != -1 + + +def test_error_no_cyclic_reference() -> None: + # This test case ensures that when an error is raised from C++ side, + # there is no cyclic reference that slows down the garbage collection. + # Please see `_with_append_backtrace` in error.py + + # temporarily disable gc + gc.disable() + + try: + # We should create a class as a probe to detect gc activity + # beacuse weakref doesn't support list, dict or other trivial types + class SampleObject: ... + + # trigger a C++ side KeyError by accessing a non-existent key + def trigger_cpp_side_error() -> None: + try: + tmp_map: tvm_ffi.Map = tvm_ffi.Map(dict()) + tmp_map["a"] + except KeyError: + pass + + def may_create_cyclic_reference() -> weakref.ReferenceType: + obj = SampleObject() + trigger_cpp_side_error() + return weakref.ref(obj) + + wref = may_create_cyclic_reference() + + # if the object is not collected, wref() will return the object + assert wref() is None, "Cyclic reference occurs inside error handling pipeline" + + finally: + # re-enable gc whenever exception occurs + gc.enable()