Skip to content

Commit 91d69f0

Browse files
committed
[FFI] Support Opaque PyObject (apache#18270)
* [FFI] Support Opaque PyObject This PR adds support of Opaque PyObject. When a type in python is not natively supported by ffi, it will now be converted to an Opaque PyObject on the backend, such opaque object will retain their lifecycle automatically and can still be used by registering python callbacks or store in container and return to the frontend. * Round of grammar polishment
1 parent 1b07159 commit 91d69f0

File tree

10 files changed

+299
-93
lines changed

10 files changed

+299
-93
lines changed

include/tvm/ffi/c_api.h

Lines changed: 114 additions & 85 deletions
Large diffs are not rendered by default.

python/tvm_ffi/convert.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,13 @@ def convert(value: Any) -> Any:
5656
return None
5757
elif hasattr(value, "__dlpack__"):
5858
return core.from_dlpack(
59-
value,
60-
required_alignment=core.__dlpack_auto_import_required_alignment__,
59+
value, required_alignment=core.__dlpack_auto_import_required_alignment__
6160
)
6261
elif isinstance(value, Exception):
6362
return core._convert_to_ffi_error(value)
6463
else:
65-
raise TypeError(f"don't know how to convert type {type(value)} to object")
64+
# in this case, it is an opaque python object
65+
return core._convert_to_opaque_object(value)
6666

6767

6868
core._set_func_convert_to_object(convert)

python/tvm_ffi/cython/base.pxi

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ cdef extern from "tvm/ffi/c_api.h":
5353
kTVMFFIArray = 71
5454
kTVMFFIMap = 72
5555
kTVMFFIModule = 73
56+
kTVMFFIOpaquePyObject = 74
5657

5758

5859
ctypedef void* TVMFFIObjectHandle
@@ -111,6 +112,9 @@ cdef extern from "tvm/ffi/c_api.h":
111112
const char* data
112113
size_t size
113114

115+
ctypedef struct TVMFFIOpaqueObjectCell:
116+
void* handle
117+
114118
ctypedef struct TVMFFIShapeCell:
115119
const int64_t* data
116120
size_t size
@@ -172,6 +176,8 @@ cdef extern from "tvm/ffi/c_api.h":
172176
const TVMFFITypeMetadata* metadata
173177

174178
int TVMFFIObjectDecRef(TVMFFIObjectHandle obj) nogil
179+
int TVMFFIObjectCreateOpaque(void* handle, int32_t type_index,
180+
void (*deleter)(void*), TVMFFIObjectHandle* out) nogil
175181
int TVMFFIObjectGetTypeIndex(TVMFFIObjectHandle obj) nogil
176182
int TVMFFIFunctionCall(TVMFFIObjectHandle func, TVMFFIAny* args, int32_t num_args,
177183
TVMFFIAny* result) nogil
@@ -203,6 +209,7 @@ cdef extern from "tvm/ffi/c_api.h":
203209
TVMFFIByteArray TVMFFISmallBytesGetContentByteArray(const TVMFFIAny* value) nogil
204210
TVMFFIByteArray* TVMFFIBytesGetByteArrayPtr(TVMFFIObjectHandle obj) nogil
205211
TVMFFIErrorCell* TVMFFIErrorGetCellPtr(TVMFFIObjectHandle obj) nogil
212+
TVMFFIOpaqueObjectCell* TVMFFIOpaqueObjectGetCellPtr(TVMFFIObjectHandle obj) nogil
206213
TVMFFIShapeCell* TVMFFIShapeGetCellPtr(TVMFFIObjectHandle obj) nogil
207214
DLTensor* TVMFFINDArrayGetDLTensorPtr(TVMFFIObjectHandle obj) nogil
208215
DLDevice TVMFFIDLDeviceFromIntPair(int32_t device_type, int32_t device_id) nogil

python/tvm_ffi/cython/function.pxi

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ cdef inline object make_ret(TVMFFIAny result):
4646
if type_index == kTVMFFINDArray:
4747
# specially handle NDArray as it needs a special dltensor field
4848
return make_ndarray_from_any(result)
49+
elif type_index == kTVMFFIOpaquePyObject:
50+
return make_ret_opaque_object(result)
4951
elif type_index >= kTVMFFIStaticObjectBegin:
5052
return make_ret_object(result)
5153
elif type_index == kTVMFFINone:
@@ -182,7 +184,10 @@ cdef inline int make_args(tuple py_args, TVMFFIAny* out, list temp_args,
182184
out[i].v_ptr = (<Object>arg).chandle
183185
temp_args.append(arg)
184186
else:
185-
raise TypeError("Unsupported argument type: %s" % type(arg))
187+
arg = _convert_to_opaque_object(arg)
188+
out[i].type_index = kTVMFFIOpaquePyObject
189+
out[i].v_ptr = (<Object>arg).chandle
190+
temp_args.append(arg)
186191

187192

188193
cdef inline int FuncCall3(void* chandle,
@@ -431,9 +436,9 @@ def _get_global_func(name, allow_missing):
431436

432437

433438
# handle callbacks
434-
cdef void tvm_ffi_callback_deleter(void* fhandle) noexcept with gil:
435-
local_pyfunc = <object>(fhandle)
436-
Py_DECREF(local_pyfunc)
439+
cdef void tvm_ffi_pyobject_deleter(void* fhandle) noexcept with gil:
440+
local_pyobject = <object>(fhandle)
441+
Py_DECREF(local_pyobject)
437442

438443

439444
cdef int tvm_ffi_callback(void* context,
@@ -468,12 +473,27 @@ def _convert_to_ffi_func(object pyfunc):
468473
CHECK_CALL(TVMFFIFunctionCreate(
469474
<void*>(pyfunc),
470475
tvm_ffi_callback,
471-
tvm_ffi_callback_deleter,
476+
tvm_ffi_pyobject_deleter,
472477
&chandle))
473478
ret = Function.__new__(Function)
474479
(<Object>ret).chandle = chandle
475480
return ret
476481

482+
483+
def _convert_to_opaque_object(object pyobject):
484+
"""Convert a python object to TVM FFI opaque object"""
485+
cdef TVMFFIObjectHandle chandle
486+
Py_INCREF(pyobject)
487+
CHECK_CALL(TVMFFIObjectCreateOpaque(
488+
<void*>(pyobject),
489+
kTVMFFIOpaquePyObject,
490+
tvm_ffi_pyobject_deleter,
491+
&chandle))
492+
ret = OpaquePyObject.__new__(OpaquePyObject)
493+
(<Object>ret).chandle = chandle
494+
return ret
495+
496+
477497
_STR_CONSTRUCTOR = _get_global_func("ffi.String", False)
478498
_BYTES_CONSTRUCTOR = _get_global_func("ffi.Bytes", False)
479499
_OBJECT_FROM_JSON_GRAPH_STR = _get_global_func("ffi.FromJSONGraphString", True)

python/tvm_ffi/cython/object.pxi

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,17 @@ cdef class Object:
194194
(<Object>other).chandle = NULL
195195

196196

197+
cdef class OpaquePyObject(Object):
198+
"""Opaque PyObject container"""
199+
def pyobject(self):
200+
"""Get the underlying python object"""
201+
cdef object obj
202+
cdef PyObject* py_handle
203+
py_handle = <PyObject*>(TVMFFIOpaqueObjectGetCellPtr(self.chandle).handle)
204+
obj = <object>py_handle
205+
return obj
206+
207+
197208
class PyNativeObject:
198209
"""Base class of all TVM objects that also subclass python's builtin types."""
199210
__slots__ = []
@@ -252,6 +263,12 @@ cdef inline str _type_index_to_key(int32_t tindex):
252263
return py_str(PyBytes_FromStringAndSize(type_key.data, type_key.size))
253264

254265

266+
cdef inline object make_ret_opaque_object(TVMFFIAny result):
267+
obj = OpaquePyObject.__new__(OpaquePyObject)
268+
(<Object>obj).chandle = result.v_obj
269+
return obj.pyobject()
270+
271+
255272
cdef inline object make_ret_object(TVMFFIAny result):
256273
global OBJECT_TYPE
257274
cdef int32_t tindex

src/ffi/object.cc

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include <tvm/ffi/container/map.h>
2525
#include <tvm/ffi/error.h>
2626
#include <tvm/ffi/function.h>
27+
#include <tvm/ffi/memory.h>
2728
#include <tvm/ffi/reflection/registry.h>
2829
#include <tvm/ffi/string.h>
2930

@@ -385,6 +386,29 @@ class TypeTable {
385386
Map<String, int64_t> type_attr_name_to_column_index_;
386387
};
387388

389+
/**
390+
* \brief Opaque implementation
391+
*/
392+
class OpaqueObjectImpl : public Object, public TVMFFIOpaqueObjectCell {
393+
public:
394+
OpaqueObjectImpl(void* handle, void (*deleter)(void* handle)) : deleter_(deleter) {
395+
this->handle = handle;
396+
}
397+
398+
void SetTypeIndex(int32_t type_index) {
399+
details::ObjectUnsafe::GetHeader(this)->type_index = type_index;
400+
}
401+
402+
~OpaqueObjectImpl() {
403+
if (deleter_ != nullptr) {
404+
deleter_(handle);
405+
}
406+
}
407+
408+
private:
409+
void (*deleter_)(void* handle);
410+
};
411+
388412
} // namespace ffi
389413
} // namespace tvm
390414

@@ -400,6 +424,22 @@ int TVMFFIObjectIncRef(TVMFFIObjectHandle handle) {
400424
TVM_FFI_SAFE_CALL_END();
401425
}
402426

427+
int TVMFFIObjectCreateOpaque(void* handle, int32_t type_index, void (*deleter)(void* handle),
428+
TVMFFIObjectHandle* out) {
429+
TVM_FFI_SAFE_CALL_BEGIN();
430+
if (type_index != kTVMFFIOpaquePyObject) {
431+
TVM_FFI_THROW(RuntimeError) << "Only kTVMFFIOpaquePyObject is supported for now";
432+
}
433+
// create initial opaque object
434+
tvm::ffi::ObjectPtr<tvm::ffi::OpaqueObjectImpl> p =
435+
tvm::ffi::make_object<tvm::ffi::OpaqueObjectImpl>(handle, deleter);
436+
// need to set the type index after creation, because the set to RuntimeTypeIndex()
437+
// happens after the constructor is called
438+
p->SetTypeIndex(type_index);
439+
*out = tvm::ffi::details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(p));
440+
TVM_FFI_SAFE_CALL_END();
441+
}
442+
403443
int TVMFFITypeKeyToIndex(const TVMFFIByteArray* type_key, int32_t* out_tindex) {
404444
TVM_FFI_SAFE_CALL_BEGIN();
405445
out_tindex[0] = tvm::ffi::TypeTable::Global()->TypeKeyToIndex(type_key);

tests/cpp/test_object.cc

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,4 +222,29 @@ TEST(Object, WeakObjectPtrAssignment) {
222222
EXPECT_EQ(lock3->value, 777);
223223
}
224224

225+
TEST(Object, OpaqueObject) {
226+
thread_local int deleter_trigger_counter = 0;
227+
struct DummyOpaqueObject {
228+
int value;
229+
DummyOpaqueObject(int value) : value(value) {}
230+
231+
static void Deleter(void* handle) {
232+
deleter_trigger_counter++;
233+
delete static_cast<DummyOpaqueObject*>(handle);
234+
}
235+
};
236+
TVMFFIObjectHandle handle = nullptr;
237+
TVM_FFI_CHECK_SAFE_CALL(TVMFFIObjectCreateOpaque(new DummyOpaqueObject(10), kTVMFFIOpaquePyObject,
238+
DummyOpaqueObject::Deleter, &handle));
239+
ObjectPtr<Object> a =
240+
details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<Object*>(handle));
241+
EXPECT_EQ(a->type_index(), kTVMFFIOpaquePyObject);
242+
EXPECT_EQ(static_cast<DummyOpaqueObject*>(TVMFFIOpaqueObjectGetCellPtr(a.get())->handle)->value,
243+
10);
244+
EXPECT_EQ(a.use_count(), 1);
245+
EXPECT_EQ(deleter_trigger_counter, 0);
246+
a.reset();
247+
EXPECT_EQ(deleter_trigger_counter, 1);
248+
}
249+
225250
} // namespace

tests/python/test_container.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,28 @@ def test_int_map():
6666
assert tuple(amap.values()) == (2, 3)
6767

6868

69+
def test_array_map_of_opaque_object():
70+
class MyObject:
71+
def __init__(self, value):
72+
self.value = value
73+
74+
a = tvm_ffi.convert([MyObject("hello"), MyObject(1)])
75+
assert isinstance(a, tvm_ffi.Array)
76+
assert len(a) == 2
77+
assert isinstance(a[0], MyObject)
78+
assert a[0].value == "hello"
79+
assert isinstance(a[1], MyObject)
80+
assert a[1].value == 1
81+
82+
y = tvm_ffi.convert({"a": MyObject(1), "b": MyObject("hello")})
83+
assert isinstance(y, tvm_ffi.Map)
84+
assert len(y) == 2
85+
assert isinstance(y["a"], MyObject)
86+
assert y["a"].value == 1
87+
assert isinstance(y["b"], MyObject)
88+
assert y["b"].value == "hello"
89+
90+
6991
def test_str_map():
7092
data = []
7193
for i in reversed(range(10)):

tests/python/test_function.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import gc
1919
import ctypes
20+
import sys
2021
import numpy as np
2122
import tvm_ffi
2223

@@ -161,3 +162,27 @@ def check1():
161162

162163
check0()
163164
check1()
165+
166+
167+
def test_echo_with_opaque_object():
168+
class MyObject:
169+
def __init__(self, value):
170+
self.value = value
171+
172+
fecho = tvm_ffi.get_global_func("testing.echo")
173+
x = MyObject("hello")
174+
assert sys.getrefcount(x) == 2
175+
y = fecho(x)
176+
assert isinstance(y, MyObject)
177+
assert y is x
178+
assert sys.getrefcount(x) == 3
179+
180+
def py_callback(z):
181+
"""python callback with opaque object"""
182+
assert z is x
183+
return z
184+
185+
fcallback = tvm_ffi.convert(py_callback)
186+
z = fcallback(x)
187+
assert z is x
188+
assert sys.getrefcount(x) == 4

tests/python/test_object.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
import pytest
18+
import sys
1819

1920
import tvm_ffi
2021

@@ -68,3 +69,23 @@ def test_derived_object():
6869

6970
obj0.v_i64 = 21
7071
assert obj0.v_i64 == 21
72+
73+
74+
class MyObject:
75+
def __init__(self, value):
76+
self.value = value
77+
78+
79+
def test_opaque_object():
80+
obj0 = MyObject("hello")
81+
assert sys.getrefcount(obj0) == 2
82+
obj0_converted = tvm_ffi.convert(obj0)
83+
assert sys.getrefcount(obj0) == 3
84+
assert isinstance(obj0_converted, tvm_ffi.core.OpaquePyObject)
85+
obj0_cpy = obj0_converted.pyobject()
86+
assert obj0_cpy is obj0
87+
assert sys.getrefcount(obj0) == 4
88+
obj0_converted = None
89+
assert sys.getrefcount(obj0) == 3
90+
obj0_cpy = None
91+
assert sys.getrefcount(obj0) == 2

0 commit comments

Comments
 (0)