Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 52 additions & 5 deletions python/tvm_ffi/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,19 +121,66 @@ def append_traceback(
The new traceback with the appended frame.

"""
frame = self._create_frame(filename, lineno, func)
return types.TracebackType(tb, frame, frame.f_lasti, lineno)

# This approach avoids binding the created frame object to a local variable
# in `append_traceback`, which would create a reference cycle. By using a
# nested function, the frame object is a temporary that is not held by
# the locals of `append_traceback`. See the diagram in `_with_append_backtrace`
# and PR #327 for more details.
def create(
tb: types.TracebackType | None, frame: types.FrameType, lineno: int
) -> types.TracebackType:
return types.TracebackType(tb, frame, frame.f_lasti, lineno)

return create(tb, self._create_frame(filename, lineno, func), lineno)


_TRACEBACK_MANAGER = TracebackManager()


def _with_append_backtrace(py_error: BaseException, backtrace: str) -> BaseException:
"""Append the backtrace to the py_error and return it."""
# We manually delete py_error and tb to avoid reference cycle, making it faster to gc the locals inside the frame
# please see pull request #327 for more details
#
# Memory Cycle Diagram:
#
# [Stack Frames] [Heap Objects]
# +-------------------+
# | outside functions | -----------------------> [ Tensor ]
# +-------------------+ (Held by cycle, slow to free)
# ^
# | f_back
# +-------------------+ locals py_error
# | py_error (this) | -----+--------------> [ BaseException ]
# +-------------------+ | |
# ^ | | (with_traceback)
# | f_back | v
# +-------------------+ +--------------> [ Traceback Obj ]
# | append_traceback | tb |
# +-------------------+ |
# ^ |
# | f_back |
# +-------------------+ |
# | _create_frame | |
# +-------------------+ |
# ^ |
# | f_back |
# +-------------------+ |
# | _get_frame | <----------------------------+
# +-------------------+ (Cycle closes here)
tb = py_error.__traceback__
for filename, lineno, func in _parse_backtrace(backtrace):
tb = _TRACEBACK_MANAGER.append_traceback(tb, filename, lineno, func)
return py_error.with_traceback(tb)
try:
for filename, lineno, func in _parse_backtrace(backtrace):
tb = _TRACEBACK_MANAGER.append_traceback(tb, filename, lineno, func)
return py_error.with_traceback(tb)
finally:
# We explicitly break the reference cycle here. The `finally` block is
# executed just before the function returns, after the `return` expression
# in the `try` block has been evaluated. Deleting `py_error` and `tb`
# here ensures they are not held by this function's frame's locals,
# which resolves the cycle.
del py_error, tb


def _traceback_to_backtrace_str(tb: types.TracebackType | None) -> str:
Expand Down
38 changes: 38 additions & 0 deletions tests/python/test_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
# under the License.


import gc
import weakref
from typing import NoReturn

import pytest
Expand Down Expand Up @@ -113,3 +115,39 @@ def raise_cxx_error() -> None:
ffi_error2 = fecho(e)
assert ffi_error1.backtrace.find("raise_cxx_error") != -1
assert ffi_error2.backtrace.find("raise_cxx_error") != -1


def test_error_no_cyclic_reference() -> None:
# This test case ensures that when an error is raised from C++ side,
# there is no cyclic reference that slows down the garbage collection.
# Please see `_with_append_backtrace` in error.py

# temporarily disable gc
gc.disable()

try:
# We should create a class as a probe to detect gc activity
# beacuse weakref doesn't support list, dict or other trivial types
class SampleObject: ...

# trigger a C++ side KeyError by accessing a non-existent key
def trigger_cpp_side_error() -> None:
try:
tmp_map: tvm_ffi.Map = tvm_ffi.Map(dict())
tmp_map["a"]
except KeyError:
pass

def may_create_cyclic_reference() -> weakref.ReferenceType:
obj = SampleObject()
trigger_cpp_side_error()
return weakref.ref(obj)

wref = may_create_cyclic_reference()

# if the object is not collected, wref() will return the object
assert wref() is None, "Cyclic reference occurs inside error handling pipeline"

finally:
# re-enable gc whenever exception occurs
gc.enable()