diff --git a/python/tvm/_ffi/_cython/base.pxi b/python/tvm/_ffi/_cython/base.pxi index 0f7e5fcae6bd..887ac123ce61 100644 --- a/python/tvm/_ffi/_cython/base.pxi +++ b/python/tvm/_ffi/_cython/base.pxi @@ -201,6 +201,10 @@ cdef inline void* c_handle(object handle): # python env API cdef extern from "Python.h": int PyErr_CheckSignals() + void* PyGILState_Ensure() + void PyGILState_Release(void*) + void Py_IncRef(void*) + void Py_DecRef(void*) cdef extern from "tvm/runtime/c_backend_api.h": int TVMBackendRegisterEnvCAPI(const char* name, void* ptr) @@ -210,11 +214,13 @@ cdef _init_env_api(): # so backend can call tvm::runtime::EnvCheckSignals to check # signal when executing a long running function. # - # This feature is only enabled in cython for now due to problems of calling - # these functions in ctypes. - # - # When the functions are not registered, the signals will be handled - # only when the FFI function returns. + # Also registers the gil state release and ensure as PyErr_CheckSignals + # function is called with gil released and we need to regrab the gil CHECK_CALL(TVMBackendRegisterEnvCAPI(c_str("PyErr_CheckSignals"), PyErr_CheckSignals)) + CHECK_CALL(TVMBackendRegisterEnvCAPI(c_str("PyGILState_Ensure"), PyGILState_Ensure)) + CHECK_CALL(TVMBackendRegisterEnvCAPI(c_str("PyGILState_Release"), PyGILState_Release)) + CHECK_CALL(TVMBackendRegisterEnvCAPI(c_str("PyGILState_Release"), PyGILState_Release)) + CHECK_CALL(TVMBackendRegisterEnvCAPI(c_str("Py_IncRef"), Py_IncRef)) + CHECK_CALL(TVMBackendRegisterEnvCAPI(c_str("Py_DecRef"), Py_DecRef)) _init_env_api() diff --git a/python/tvm/_ffi/_cython/packed_func.pxi b/python/tvm/_ffi/_cython/packed_func.pxi index 6e062ab5f199..b9516e79e36c 100644 --- a/python/tvm/_ffi/_cython/packed_func.pxi +++ b/python/tvm/_ffi/_cython/packed_func.pxi @@ -376,19 +376,3 @@ def _set_class_object_generic(object_generic_class, func_convert_to_object): global _FUNC_CONVERT_TO_OBJECT _CLASS_OBJECT_GENERIC = object_generic_class _FUNC_CONVERT_TO_OBJECT = func_convert_to_object - -# Py_INCREF and Py_DECREF are C macros, not function objects. -# Therefore, providing a wrapper function. -cdef void _py_incref_wrapper(void* py_object): - Py_INCREF(py_object) -cdef void _py_decref_wrapper(void* py_object): - Py_DECREF(py_object) - -def _init_pythonapi_inc_def_ref(): - register_func = TVMBackendRegisterEnvCAPI - register_func(c_str("Py_IncRef"), _py_incref_wrapper) - register_func(c_str("Py_DecRef"), _py_decref_wrapper) - register_func(c_str("PyGILState_Ensure"), PyGILState_Ensure) - register_func(c_str("PyGILState_Release"), PyGILState_Release) - -_init_pythonapi_inc_def_ref() diff --git a/src/runtime/registry.cc b/src/runtime/registry.cc index 0a034a7b5897..09674edf3584 100644 --- a/src/runtime/registry.cc +++ b/src/runtime/registry.cc @@ -183,10 +183,14 @@ class EnvCAPIRegistry { // implementation of tvm::runtime::EnvCheckSignals void CheckSignals() { // check python signal to see if there are exception raised - if (pyerr_check_signals != nullptr && (*pyerr_check_signals)() != 0) { - // The error will let FFI know that the frontend environment - // already set an error. - throw EnvErrorAlreadySet(""); + if (pyerr_check_signals != nullptr) { + // The C++ env comes without gil, so we need to grab gil here + WithGIL context(this); + if ((*pyerr_check_signals)() != 0) { + // The error will let FFI know that the frontend environment + // already set an error. + throw EnvErrorAlreadySet(""); + } } } diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc index 928cdfcab80b..52ffedda8030 100644 --- a/src/support/ffi_testing.cc +++ b/src/support/ffi_testing.cc @@ -178,6 +178,14 @@ TVM_REGISTER_GLOBAL("testing.sleep_in_ffi").set_body_typed([](double timeout) { std::this_thread::sleep_for(duration); }); +TVM_REGISTER_GLOBAL("testing.check_signals").set_body_typed([](double sleep_period) { + while (true) { + std::chrono::duration duration(static_cast(sleep_period * 1e9)); + std::this_thread::sleep_for(duration); + runtime::EnvCheckSignals(); + } +}); + TVM_REGISTER_GLOBAL("testing.ReturnsVariant").set_body_typed([](int x) -> Variant { if (x % 2 == 0) { return IntImm(DataType::Int(64), x / 2);