From f006162d5526de13db4e9436d9474196e86edcb9 Mon Sep 17 00:00:00 2001 From: tqchen Date: Sat, 13 Sep 2025 00:01:36 -0400 Subject: [PATCH] [FFI][ABI][REFACTOR] Better String and nested container handling This PR improves the overall String/Bytes and nested container handling It also fixes a bug for temp object recycling when temp object. - Introduce formal API for string/bytes creation - Updates the tuple/dict conversion to also preserve the torch stream - So if a function takes a list of torch.Tensor, torch stream will be setup in context - Optimizes recursive argument conversion by moving most logic into c++ --- ffi/include/tvm/ffi/c_api.h | 19 ++ ffi/pyproject.toml | 2 +- ffi/python/tvm_ffi/_convert.py | 11 +- .../tvm_ffi/_optional_torch_c_dlpack.py | 1 - ffi/python/tvm_ffi/cython/base.pxi | 11 + ffi/python/tvm_ffi/cython/function.pxi | 267 ++++++++++++++---- ffi/python/tvm_ffi/cython/object.pxi | 11 +- ffi/python/tvm_ffi/cython/string.pxi | 5 - .../tvm_ffi/cython/tvm_ffi_python_helpers.h | 101 ++++++- ffi/src/ffi/object.cc | 18 ++ ffi/tests/python/test_function.py | 33 +++ ffi/tests/python/test_load_inline.py | 13 +- src/runtime/disco/protocol.h | 6 +- src/runtime/minrpc/rpc_reference.h | 4 +- src/runtime/rpc/rpc_endpoint.cc | 44 ++- 15 files changed, 432 insertions(+), 114 deletions(-) diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h index a53dac4d00af..f13f820b7fc9 100644 --- a/ffi/include/tvm/ffi/c_api.h +++ b/ffi/include/tvm/ffi/c_api.h @@ -555,6 +555,25 @@ TVM_FFI_DLL int TVMFFITensorFromDLPackVersioned(DLManagedTensorVersioned* from, */ TVM_FFI_DLL int TVMFFITensorToDLPackVersioned(TVMFFIObjectHandle from, DLManagedTensorVersioned** out); +//--------------------------------------------------------------- +// Section: string/bytes support APIs. +// These APIs are used to simplify the string/bytes construction +//--------------------------------------------------------------- +/*! + * \brief Reinterpret the content of TVMFFIByteArray to String. + * \param input The TVMFFIByteArray to convert. + * \param out The output String owned by the caller, maybe a SmallStr or a Str object. + * \return 0 on success, nonzero on failure. + */ +TVM_FFI_DLL int TVMFFIStringFromByteArray(const TVMFFIByteArray* input, TVMFFIAny* out); + +/*! + * \brief Reinterpret the content of TVMFFIByteArray to Bytes. + * \param input The TVMFFIByteArray to convert. + * \param out The output Bytes owned by the caller, maybe a SmallBytes or a Bytes object. + * \return 0 on success, nonzero on failure. + */ +TVM_FFI_DLL int TVMFFIBytesFromByteArray(const TVMFFIByteArray* input, TVMFFIAny* out); //--------------------------------------------------------------- // Section: dtype string support APIs. diff --git a/ffi/pyproject.toml b/ffi/pyproject.toml index 8c146f41c4e2..cc2df03f0a6b 100644 --- a/ffi/pyproject.toml +++ b/ffi/pyproject.toml @@ -17,7 +17,7 @@ [project] name = "apache-tvm-ffi" -version = "0.1.0a12" +version = "0.1.0a13" description = "tvm ffi" authors = [{ name = "TVM FFI team" }] diff --git a/ffi/python/tvm_ffi/_convert.py b/ffi/python/tvm_ffi/_convert.py index b1b972633d86..a0b6c1b117e5 100644 --- a/ffi/python/tvm_ffi/_convert.py +++ b/ffi/python/tvm_ffi/_convert.py @@ -40,13 +40,9 @@ def convert(value: Any) -> Any: automatically converted. So this function is mainly only used in internal or testing scenarios. """ - if isinstance(value, core.Object): + if isinstance(value, (core.Object, core.PyNativeObject, bool, Number)): return value - elif isinstance(value, core.PyNativeObject): - return value - elif isinstance(value, (bool, Number)): - return value - elif isinstance(value, (list, tuple)): + elif isinstance(value, (tuple, list)): return container.Array(value) elif isinstance(value, dict): return container.Map(value) @@ -67,6 +63,3 @@ def convert(value: Any) -> Any: else: # in this case, it is an opaque python object return core._convert_to_opaque_object(value) - - -core._set_func_convert_to_object(convert) diff --git a/ffi/python/tvm_ffi/_optional_torch_c_dlpack.py b/ffi/python/tvm_ffi/_optional_torch_c_dlpack.py index fc5851af170d..f44855247abe 100644 --- a/ffi/python/tvm_ffi/_optional_torch_c_dlpack.py +++ b/ffi/python/tvm_ffi/_optional_torch_c_dlpack.py @@ -384,7 +384,6 @@ def load_torch_c_dlpack_extension(): ], extra_cflags=["-O3"], extra_include_paths=libinfo.include_paths() + cpp_extension.include_paths("cuda"), - verbose=True, ) # set the dlpack related flags torch.Tensor.__c_dlpack_from_pyobject__ = mod.TorchDLPackFromPyObjectPtr() diff --git a/ffi/python/tvm_ffi/cython/base.pxi b/ffi/python/tvm_ffi/cython/base.pxi index fdb06f51055e..ef583c752908 100644 --- a/ffi/python/tvm_ffi/cython/base.pxi +++ b/ffi/python/tvm_ffi/cython/base.pxi @@ -212,6 +212,8 @@ cdef extern from "tvm/ffi/c_api.h": TVMFFIByteArray* traceback) nogil int TVMFFITypeKeyToIndex(TVMFFIByteArray* type_key, int32_t* out_tindex) nogil + int TVMFFIStringFromByteArray(TVMFFIByteArray* input_, TVMFFIAny* out) nogil + int TVMFFIBytesFromByteArray(TVMFFIByteArray* input_, TVMFFIAny* out) nogil int TVMFFIDataTypeFromString(TVMFFIByteArray* str, DLDataType* out) nogil int TVMFFIDataTypeToString(const DLDataType* dtype, TVMFFIAny* out) nogil const TVMFFIByteArray* TVMFFITraceback( @@ -284,6 +286,15 @@ cdef extern from "tvm_ffi_python_helpers.h": DLPackToPyObject* out_dlpack_importer ) except -1 + int TVMFFIPyConstructorCall( + TVMFFIPyArgSetterFactory setter_factory, + void* chandle, + PyObject* py_arg_tuple, + TVMFFIAny* result, + int* c_api_ret_code, + TVMFFIPyCallContext* parent_ctx + ) except -1 + int TVMFFIPyCallFieldSetter( TVMFFIPyArgSetterFactory setter_factory, TVMFFIFieldSetter field_setter, diff --git a/ffi/python/tvm_ffi/cython/function.pxi b/ffi/python/tvm_ffi/cython/function.pxi index 9b86054b7102..71c9522ddba4 100644 --- a/ffi/python/tvm_ffi/cython/function.pxi +++ b/ffi/python/tvm_ffi/cython/function.pxi @@ -88,6 +88,27 @@ cdef inline object make_ret(TVMFFIAny result, DLPackToPyObject c_dlpack_to_pyobj raise ValueError("Unhandled type index %d" % type_index) +##---------------------------------------------------------------------------- +## Helper to simplify calling constructor +##---------------------------------------------------------------------------- +cdef inline int ConstructorCall(void* constructor_handle, + PyObject* py_arg_tuple, + void** handle, + TVMFFIPyCallContext* parent_ctx) except -1: + """Call contructor of a handle function""" + cdef TVMFFIAny result + cdef int c_api_ret_code + # IMPORTANT: caller need to initialize result->type_index to kTVMFFINone + result.type_index = kTVMFFINone + result.v_int64 = 0 + TVMFFIPyConstructorCall( + TVMFFIPyArgSetterFactory_, constructor_handle, py_arg_tuple, &result, &c_api_ret_code, + parent_ctx + ) + CHECK_CALL(c_api_ret_code) + handle[0] = result.v_ptr + return 0 + ##---------------------------------------------------------------------------- ## Implementation of setters using same naming style as TVMFFIPyArgSetterXXX_ ##---------------------------------------------------------------------------- @@ -244,18 +265,33 @@ cdef int TVMFFIPyArgSetterStr_( ) except -1: """Setter for str""" cdef object arg = py_arg + cdef bytes tstr = arg.encode("utf-8") + cdef char* data + cdef Py_ssize_t size + cdef TVMFFIByteArray cdata + + PyBytes_AsStringAndSize(tstr, &data, &size) + cdata.data = data + cdata.size = size + CHECK_CALL(TVMFFIStringFromByteArray(&cdata, out)) + if out.type_index >= kTVMFFIStaticObjectBegin: + TVMFFIPyPushTempFFIObject(ctx, out.v_ptr) + return 0 + - if isinstance(arg, PyNativeObject) and arg.__tvm_ffi_object__ is not None: +cdef int TVMFFIPyArgSetterPyNativeObjectStr_( + TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, + PyObject* py_arg, TVMFFIAny* out +) except -1: + """Specially handle String as its __tvm_ffi_object__ may be empty""" + cdef object arg = py_arg + # need to check if the arg is a large string returned from ffi + if arg.__tvm_ffi_object__ is not None: arg = arg.__tvm_ffi_object__ out.type_index = TVMFFIObjectGetTypeIndex((arg).chandle) out.v_ptr = (arg).chandle return 0 - - tstr = c_str(arg) - out.type_index = kTVMFFIRawStr - out.v_c_str = tstr - TVMFFIPyPushTempPyObject(ctx, tstr) - return 0 + return TVMFFIPyArgSetterStr_(handle, ctx, py_arg, out) cdef int TVMFFIPyArgSetterBytes_( @@ -265,17 +301,50 @@ cdef int TVMFFIPyArgSetterBytes_( """Setter for bytes""" cdef object arg = py_arg - if isinstance(arg, PyNativeObject) and arg.__tvm_ffi_object__ is not None: + if isinstance(arg, bytearray): + arg = bytes(arg) + + cdef char* data + cdef Py_ssize_t size + cdef TVMFFIByteArray cdata + + PyBytes_AsStringAndSize(arg, &data, &size) + cdata.data = data + cdata.size = size + CHECK_CALL(TVMFFIBytesFromByteArray(&cdata, out)) + + if out.type_index >= kTVMFFIStaticObjectBegin: + TVMFFIPyPushTempFFIObject(ctx, out.v_ptr) + return 0 + + +cdef int TVMFFIPyArgSetterPyNativeObjectBytes_( + TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, + PyObject* py_arg, TVMFFIAny* out +) except -1: + """Specially handle Bytes as its __tvm_ffi_object__ may be empty""" + cdef object arg = py_arg + # need to check if the arg is a large bytes returned from ffi + if arg.__tvm_ffi_object__ is not None: arg = arg.__tvm_ffi_object__ out.type_index = TVMFFIObjectGetTypeIndex((arg).chandle) out.v_ptr = (arg).chandle return 0 + return TVMFFIPyArgSetterBytes_(handle, ctx, py_arg, out) - arg = ByteArrayArg(arg) - out.type_index = kTVMFFIByteArrayPtr - out.v_int64 = 0 - out.v_ptr = (arg).cptr() - TVMFFIPyPushTempPyObject(ctx, arg) + +cdef int TVMFFIPyArgSetterPyNativeObjectGeneral_( + TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, + PyObject* py_arg, TVMFFIAny* out +) except -1: + """Specially handle Bytes as its __tvm_ffi_object__ may be empty""" + cdef object arg = py_arg + if arg.__tvm_ffi_object__ is None: + raise ValueError(f"__tvm_ffi_object__ is None for {type(arg)}") + assert arg.__tvm_ffi_object__ is not None + arg = arg.__tvm_ffi_object__ + out.type_index = TVMFFIObjectGetTypeIndex((arg).chandle) + out.v_ptr = (arg).chandle return 0 @@ -306,10 +375,11 @@ cdef int TVMFFIPyArgSetterCallable_( ) except -1: """Setter for Callable""" cdef object arg = py_arg - arg = _convert_to_ffi_func(arg) - out.type_index = TVMFFIObjectGetTypeIndex((arg).chandle) - out.v_ptr = (arg).chandle - TVMFFIPyPushTempPyObject(ctx, arg) + cdef TVMFFIObjectHandle chandle + _convert_to_ffi_func_handle(arg, &chandle) + out.type_index = TVMFFIObjectGetTypeIndex(chandle) + out.v_ptr = chandle + TVMFFIPyPushTempFFIObject(ctx, chandle) return 0 @@ -326,27 +396,79 @@ cdef int TVMFFIPyArgSetterException_( return 0 +cdef int TVMFFIPyArgSetterTuple_( + TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, + PyObject* py_arg, TVMFFIAny* out +) except -1: + """Setter for Tuple""" + # recursively construct a new tuple + cdef TVMFFIObjectHandle chandle + ConstructorCall(_CONSTRUCTOR_ARRAY.chandle, py_arg, &chandle, ctx) + out.type_index = TVMFFIObjectGetTypeIndex(chandle) + out.v_ptr = chandle + TVMFFIPyPushTempFFIObject(ctx, chandle) + return 0 + + +cdef int TVMFFIPyArgSetterTupleLike_( + TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, + PyObject* py_arg, TVMFFIAny* out +) except -1: + """Setter for TupleLike""" + # recursively construct a new tuple + cdef tuple tuple_arg = tuple(py_arg) + cdef TVMFFIObjectHandle chandle + ConstructorCall(_CONSTRUCTOR_ARRAY.chandle, tuple_arg, &chandle, ctx) + out.type_index = TVMFFIObjectGetTypeIndex(chandle) + out.v_ptr = chandle + TVMFFIPyPushTempFFIObject(ctx, chandle) + return 0 + + +cdef int TVMFFIPyArgSetterMap_( + TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, + PyObject* py_arg, TVMFFIAny* out +) except -1: + """Setter for Map""" + # recursively construct a new map + cdef dict dict_arg = py_arg + cdef list list_kvs = [] + for k, v in dict_arg.items(): + list_kvs.append(k) + list_kvs.append(v) + cdef tuple_arg_kvs = tuple(list_kvs) + cdef TVMFFIObjectHandle chandle + ConstructorCall(_CONSTRUCTOR_MAP.chandle, tuple_arg_kvs, &chandle, ctx) + out.type_index = TVMFFIObjectGetTypeIndex(chandle) + out.v_ptr = chandle + TVMFFIPyPushTempFFIObject(ctx, chandle) + return 0 + + +cdef int TVMFFIPyArgSetterObjectConvertible_( + TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, + PyObject* py_arg, TVMFFIAny* out +) except -1: + """Setter for ObjectConvertible""" + # recursively construct a new map + cdef object arg = py_arg + arg = arg.asobject() + out.type_index = TVMFFIObjectGetTypeIndex((arg).chandle) + out.v_ptr = (arg).chandle + TVMFFIPyPushTempPyObject(ctx, arg) + + cdef int TVMFFIPyArgSetterFallback_( TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, PyObject* py_arg, TVMFFIAny* out ) except -1: """Fallback setter for all other types""" cdef object arg = py_arg - # fallback must contain PyNativeObject check - if isinstance(arg, PyNativeObject) and arg.__tvm_ffi_object__ is not None: - arg = arg.__tvm_ffi_object__ - out.type_index = TVMFFIObjectGetTypeIndex((arg).chandle) - out.v_ptr = (arg).chandle - elif isinstance(arg, (list, tuple, dict, ObjectConvertible)): - arg = _FUNC_CONVERT_TO_OBJECT(arg) - out.type_index = TVMFFIObjectGetTypeIndex((arg).chandle) - out.v_ptr = (arg).chandle - TVMFFIPyPushTempPyObject(ctx, arg) - else: - arg = _convert_to_opaque_object(arg) - out.type_index = kTVMFFIOpaquePyObject - out.v_ptr = (arg).chandle - TVMFFIPyPushTempPyObject(ctx, arg) + cdef TVMFFIObjectHandle chandle + _convert_to_opaque_object_handle(arg, &chandle) + out.type_index = kTVMFFIOpaquePyObject + out.v_ptr = chandle + TVMFFIPyPushTempFFIObject(ctx, chandle) cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) except -1: @@ -407,12 +529,32 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) exce if isinstance(arg, _CLASS_DEVICE): out.func = TVMFFIPyArgSetterDevice_ return 0 + if isinstance(arg, PyNativeObject): + # check for PyNativeObject + # this check must happen before str/bytes/tuple + if isinstance(arg, str): + out.func = TVMFFIPyArgSetterPyNativeObjectStr_ + return 0 + if isinstance(arg, bytes): + out.func = TVMFFIPyArgSetterPyNativeObjectBytes_ + return 0 + out.func = TVMFFIPyArgSetterPyNativeObjectGeneral_ + return 0 if isinstance(arg, str): out.func = TVMFFIPyArgSetterStr_ return 0 if isinstance(arg, (bytes, bytearray)): out.func = TVMFFIPyArgSetterBytes_ return 0 + if isinstance(arg, tuple): + out.func = TVMFFIPyArgSetterTuple_ + return 0 + if isinstance(arg, list): + out.func = TVMFFIPyArgSetterTupleLike_ + return 0 + if isinstance(arg, dict): + out.func = TVMFFIPyArgSetterMap_ + return 0 if isinstance(arg, ctypes.c_void_p): out.func = TVMFFIPyArgSetterCtypesVoidPtr_ return 0 @@ -422,6 +564,9 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) exce if isinstance(arg, Exception): out.func = TVMFFIPyArgSetterException_ return 0 + if isinstance(arg, ObjectConvertible): + out.func = TVMFFIPyArgSetterObjectConvertible_ + return 0 # default to opaque object out.func = TVMFFIPyArgSetterFallback_ return 0 @@ -429,24 +574,6 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) exce #--------------------------------------------------------------------------------------------- ## Implementation of function calling #--------------------------------------------------------------------------------------------- -cdef inline int ConstructorCall(void* constructor_handle, - tuple args, - void** handle) except -1: - """Call contructor of a handle function""" - cdef TVMFFIAny result - cdef int c_api_ret_code - # IMPORTANT: caller need to initialize result->type_index to kTVMFFINone - result.type_index = kTVMFFINone - result.v_int64 = 0 - TVMFFIPyFuncCall( - TVMFFIPyArgSetterFactory_, constructor_handle, args, &result, &c_api_ret_code, - False, NULL - ) - CHECK_CALL(c_api_ret_code) - handle[0] = result.v_ptr - return 0 - - cdef class Function(Object): """Python class that wraps a function with tvm-ffi ABI. @@ -670,29 +797,45 @@ cdef int tvm_ffi_callback(void* context, return -1 -def _convert_to_ffi_func(object pyfunc): - """Convert a python function to TVM FFI function""" - cdef TVMFFIObjectHandle chandle +cdef inline int _convert_to_ffi_func_handle( + object pyfunc, TVMFFIObjectHandle* out_handle +) except -1: + """Convert a python function to TVM FFI function handle""" Py_INCREF(pyfunc) CHECK_CALL(TVMFFIFunctionCreate( (pyfunc), tvm_ffi_callback, tvm_ffi_pyobject_deleter, - &chandle)) + out_handle)) + return 0 + + +def _convert_to_ffi_func(object pyfunc): + """Convert a python function to TVM FFI function""" + cdef TVMFFIObjectHandle chandle + _convert_to_ffi_func_handle(pyfunc, &chandle) ret = Function.__new__(Function) (ret).chandle = chandle return ret -def _convert_to_opaque_object(object pyobject): - """Convert a python object to TVM FFI opaque object""" - cdef TVMFFIObjectHandle chandle +cdef inline int _convert_to_opaque_object_handle( + object pyobject, TVMFFIObjectHandle* out_handle +) except -1: + """Convert a python object to TVM FFI opaque object handle""" Py_INCREF(pyobject) CHECK_CALL(TVMFFIObjectCreateOpaque( (pyobject), kTVMFFIOpaquePyObject, tvm_ffi_pyobject_deleter, - &chandle)) + out_handle)) + return 0 + + +def _convert_to_opaque_object(object pyobject): + """Convert a python object to TVM FFI opaque object""" + cdef TVMFFIObjectHandle chandle + _convert_to_opaque_object_handle(pyobject, &chandle) ret = OpaquePyObject.__new__(OpaquePyObject) (ret).chandle = chandle return ret @@ -704,7 +847,7 @@ def _print_debug_info(): print(f"TVMFFIPyGetDispatchMapSize: {size}") -_STR_CONSTRUCTOR = _get_global_func("ffi.String", False) -_BYTES_CONSTRUCTOR = _get_global_func("ffi.Bytes", False) -_OBJECT_FROM_JSON_GRAPH_STR = _get_global_func("ffi.FromJSONGraphString", True) -_OBJECT_TO_JSON_GRAPH_STR = _get_global_func("ffi.ToJSONGraphString", True) +cdef Function _OBJECT_FROM_JSON_GRAPH_STR = _get_global_func("ffi.FromJSONGraphString", True) +cdef Function _OBJECT_TO_JSON_GRAPH_STR = _get_global_func("ffi.ToJSONGraphString", True) +cdef Function _CONSTRUCTOR_ARRAY = _get_global_func("ffi.Array", True) +cdef Function _CONSTRUCTOR_MAP = _get_global_func("ffi.Map", True) diff --git a/ffi/python/tvm_ffi/cython/object.pxi b/ffi/python/tvm_ffi/cython/object.pxi index 2a306e01ee68..1d026b250fb7 100644 --- a/ffi/python/tvm_ffi/cython/object.pxi +++ b/ffi/python/tvm_ffi/cython/object.pxi @@ -17,17 +17,12 @@ import warnings _CLASS_OBJECT = None -_FUNC_CONVERT_TO_OBJECT = None def _set_class_object(cls): global _CLASS_OBJECT _CLASS_OBJECT = cls -def _set_func_convert_to_object(func): - global _FUNC_CONVERT_TO_OBJECT - _FUNC_CONVERT_TO_OBJECT = func - def __object_repr__(obj): """Object repr function that can be overridden by assigning to it""" @@ -39,10 +34,6 @@ def _new_object(cls): return cls.__new__(cls) -_OBJECT_FROM_JSON_GRAPH_STR = None -_OBJECT_TO_JSON_GRAPH_STR = None - - class ObjectConvertible: """Base class for all classes that can be converted to object.""" @@ -144,7 +135,7 @@ cdef class Object: self.chandle = NULL cdef void* chandle ConstructorCall( - (fconstructor).chandle, args, &chandle) + (fconstructor).chandle, args, &chandle, NULL) self.chandle = chandle def same_as(self, other): diff --git a/ffi/python/tvm_ffi/cython/string.pxi b/ffi/python/tvm_ffi/cython/string.pxi index 4ab5c48ce07b..0737259f22e2 100644 --- a/ffi/python/tvm_ffi/cython/string.pxi +++ b/ffi/python/tvm_ffi/cython/string.pxi @@ -78,8 +78,3 @@ class Bytes(bytes, PyNativeObject): _register_object_by_index(kTVMFFIBytes, Bytes) - -# We special handle str/bytes constructor in cython to avoid extra cyclic deps -# as the str/bytes construction must be done in the inner loop of function call -_STR_CONSTRUCTOR = None -_BYTES_CONSTRUCTOR = None diff --git a/ffi/python/tvm_ffi/cython/tvm_ffi_python_helpers.h b/ffi/python/tvm_ffi/cython/tvm_ffi_python_helpers.h index 87b426829d1a..325b878c4fc9 100644 --- a/ffi/python/tvm_ffi/cython/tvm_ffi_python_helpers.h +++ b/ffi/python/tvm_ffi/cython/tvm_ffi_python_helpers.h @@ -226,10 +226,7 @@ class TVMFFIPyCallManager { try { // recycle the temporary arguments if any for (int i = 0; i < this->num_temp_ffi_objects; ++i) { - TVMFFIObject* obj = static_cast(this->temp_ffi_objects[i]); - if (obj->deleter != nullptr) { - obj->deleter(obj, kTVMFFIObjectDeleterFlagBitMaskBoth); - } + TVMFFIObjectDecRef(this->temp_ffi_objects[i]); } for (int i = 0; i < this->num_temp_py_objects; ++i) { Py_DecRef(static_cast(this->temp_py_objects[i])); @@ -270,9 +267,9 @@ class TVMFFIPyCallManager { * \return 0 on when there is no python error, -1 on python error * \note When an error happens on FFI side, we should return 0 and set c_api_ret_code */ - int Call(TVMFFIPyArgSetterFactory setter_factory, void* func_handle, PyObject* py_arg_tuple, - TVMFFIAny* result, int* c_api_ret_code, bool release_gil, - DLPackToPyObject* optional_out_dlpack_importer) { + int FuncCall(TVMFFIPyArgSetterFactory setter_factory, void* func_handle, PyObject* py_arg_tuple, + TVMFFIAny* result, int* c_api_ret_code, bool release_gil, + DLPackToPyObject* optional_out_dlpack_importer) { int64_t num_args = PyTuple_Size(py_arg_tuple); if (num_args == -1) return -1; try { @@ -331,6 +328,64 @@ class TVMFFIPyCallManager { } } + /* + * \brief Call a constructor with a variable number of arguments + * + * This function is similar to FuncCall, but it will not set the + * stream and tensor allocator, instead, it will synchronize the TVMFFIPyCallContext + * with the parent context. This behavior is needed for nested conversion of arguments + * where detected argument setting needs to be synchronized with final call. + * + * This function will also not release the GIL since constructor call is usually cheap. + * + * \param setter_factory The factory function to create the setter + * \param func_handle The handle of the constructor to call + * \param py_arg_tuple The arguments to the constructor + * \param result The result of the constructor + * \param c_api_ret_code The return code of the constructor + * \param parent_ctx The parent call context to + * \return 0 on success, -1 on failure + */ + int ConstructorCall(TVMFFIPyArgSetterFactory setter_factory, void* func_handle, + PyObject* py_arg_tuple, TVMFFIAny* result, int* c_api_ret_code, + TVMFFIPyCallContext* parent_ctx) { + int64_t num_args = PyTuple_Size(py_arg_tuple); + if (num_args == -1) return -1; + try { + // allocate a call stack + CallStack ctx(this, num_args); + // Iterate over the arguments and set them + for (int64_t i = 0; i < num_args; ++i) { + PyObject* py_arg = PyTuple_GetItem(py_arg_tuple, i); + TVMFFIAny* c_arg = ctx.packed_args + i; + if (SetArgument(setter_factory, &ctx, py_arg, c_arg) != 0) return -1; + } + c_api_ret_code[0] = TVMFFIFunctionCall(func_handle, ctx.packed_args, num_args, result); + // propagate the call context to the parent context + if (parent_ctx != nullptr) { + // stream and current device information + if (parent_ctx->device_type == -1) { + parent_ctx->device_type = ctx.device_type; + parent_ctx->device_id = ctx.device_id; + parent_ctx->stream = ctx.stream; + } + // DLPack allocator + if (parent_ctx->c_dlpack_tensor_allocator == nullptr) { + parent_ctx->c_dlpack_tensor_allocator = ctx.c_dlpack_tensor_allocator; + } + // DLPack importer + if (parent_ctx->c_dlpack_to_pyobject == nullptr) { + parent_ctx->c_dlpack_to_pyobject = ctx.c_dlpack_to_pyobject; + } + } + return 0; + } catch (const std::exception& ex) { + // very rare, catch c++ exception and set python error + PyErr_SetString(PyExc_RuntimeError, ex.what()); + return -1; + } + } + int SetField(TVMFFIPyArgSetterFactory setter_factory, TVMFFIFieldSetter field_setter, void* field_ptr, PyObject* py_arg, int* c_api_ret_code) { try { @@ -430,8 +485,36 @@ inline int TVMFFIPyFuncCall(TVMFFIPyArgSetterFactory setter_factory, void* func_ PyObject* py_arg_tuple, TVMFFIAny* result, int* c_api_ret_code, bool release_gil = true, DLPackToPyObject* out_dlpack_importer = nullptr) { - return TVMFFIPyCallManager::ThreadLocal()->Call(setter_factory, func_handle, py_arg_tuple, result, - c_api_ret_code, release_gil, out_dlpack_importer); + return TVMFFIPyCallManager::ThreadLocal()->FuncCall(setter_factory, func_handle, py_arg_tuple, + result, c_api_ret_code, release_gil, + out_dlpack_importer); +} + +/*! + * \brief Call a constructor function with a variable number of arguments + * + * This function is similar to TVMFFIPyFuncCall, but it will not set the + * stream and tensor allocator. Instead, it will synchronize the TVMFFIPyCallContext + * with the parent context. This behavior is needed for nested conversion of arguments + * where detected argument settings need to be synchronized with the final call. + * + * This function will also not release the GIL since constructor call is usually cheap. + * + * \param setter_factory The factory function to create the setter + * \param func_handle The handle of the function to call + * \param py_arg_tuple The arguments to the constructor + * \param result The result of the constructor + * \param c_api_ret_code The return code of the constructor + * \param parent_ctx The parent call context + * \param release_gil Whether to release the GIL + * \param out_dlpack_exporter The DLPack exporter to be used for the result + * \return 0 on success, nonzero on failure + */ +inline int TVMFFIPyConstructorCall(TVMFFIPyArgSetterFactory setter_factory, void* func_handle, + PyObject* py_arg_tuple, TVMFFIAny* result, int* c_api_ret_code, + TVMFFIPyCallContext* parent_ctx) { + return TVMFFIPyCallManager::ThreadLocal()->ConstructorCall( + setter_factory, func_handle, py_arg_tuple, result, c_api_ret_code, parent_ctx); } /*! diff --git a/ffi/src/ffi/object.cc b/ffi/src/ffi/object.cc index 9f554e3356f9..292c8e913f1d 100644 --- a/ffi/src/ffi/object.cc +++ b/ffi/src/ffi/object.cc @@ -493,3 +493,21 @@ const TVMFFITypeInfo* TVMFFIGetTypeInfo(int32_t type_index) { return tvm::ffi::TypeTable::Global()->GetTypeEntry(type_index); TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFIGetTypeInfo); } + +// string APIs, we blend into object.cc to keep things simple +int TVMFFIStringFromByteArray(const TVMFFIByteArray* input, TVMFFIAny* out) { + TVM_FFI_SAFE_CALL_BEGIN(); + // must set to none first + out->type_index = kTVMFFINone; + tvm::ffi::TypeTraits::MoveToAny(tvm::ffi::String(input->data, input->size), + out); + TVM_FFI_SAFE_CALL_END(); +} + +int TVMFFIBytesFromByteArray(const TVMFFIByteArray* input, TVMFFIAny* out) { + TVM_FFI_SAFE_CALL_BEGIN(); + // must set to none first + out->type_index = kTVMFFINone; + tvm::ffi::TypeTraits::MoveToAny(tvm::ffi::Bytes(input->data, input->size), out); + TVM_FFI_SAFE_CALL_END(); +} diff --git a/ffi/tests/python/test_function.py b/ffi/tests/python/test_function.py index dfe22a1bad80..b5a1da4f7d1d 100644 --- a/ffi/tests/python/test_function.py +++ b/ffi/tests/python/test_function.py @@ -97,6 +97,39 @@ def test_return_raw_str_bytes(): assert tvm_ffi.convert(lambda: bytearray(b"hello"))() == b"hello" +def test_string_bytes_passing(): + fecho = tvm_ffi.get_global_func("testing.echo") + use_count = tvm_ffi.get_global_func("testing.object_use_count") + # small string + assert fecho("hello") == "hello" + # large string + x = "hello" * 100 + y = fecho(x) + assert y == x + assert y.__tvm_ffi_object__ is not None + use_count(y) == 1 + # small bytes + assert fecho(b"hello") == b"hello" + # large bytes + x = b"hello" * 100 + y = fecho(x) + assert y == x + assert y.__tvm_ffi_object__ is not None + fecho(y) == 1 + + +def test_nested_container_passing(): + # test and make sure our ref counting is correct + fecho = tvm_ffi.get_global_func("testing.echo") + use_count = tvm_ffi.get_global_func("testing.object_use_count") + obj = tvm_ffi.convert((1, 2, 3)) + assert use_count(obj) == 1 + y = fecho([obj, {"a": 1, "b": obj}]) + assert use_count(y) == 1 + assert use_count(obj) == 3 + assert use_count(y[1]) == 2 + + def test_pyfunc_convert(): def add(a, b): return a + b diff --git a/ffi/tests/python/test_load_inline.py b/ffi/tests/python/test_load_inline.py index 2aa01a62ee1d..0277803730dc 100644 --- a/ffi/tests/python/test_load_inline.py +++ b/ffi/tests/python/test_load_inline.py @@ -207,13 +207,10 @@ def test_load_inline_cuda_with_env_tensor_allocator(): pytest.skip("Torch does not support __c_dlpack_tensor_allocator__") mod: Module = tvm_ffi.cpp.load_inline( name="hello", - cpp_sources=r""" - #include - - tvm::ffi::Tensor return_add_one(DLTensor* x); - """, cuda_sources=r""" #include + #include + #include __global__ void AddOneKernel(float* x, float* y, int n) { int idx = blockIdx.x * blockDim.x + threadIdx.x; @@ -223,7 +220,8 @@ def test_load_inline_cuda_with_env_tensor_allocator(): } namespace ffi = tvm::ffi; - ffi::Tensor return_add_one(DLTensor* x) { + ffi::Tensor return_add_one(ffi::Map> kwargs) { + ffi::Tensor x = kwargs["x"].get<0>(); // implementation of a library function TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor"; DLDataType f32_dtype{kDLFloat, 32, 1}; @@ -251,7 +249,8 @@ def test_load_inline_cuda_with_env_tensor_allocator(): if torch is not None: x_cuda = torch.asarray([1, 2, 3, 4, 5], dtype=torch.float32, device="cuda") - y_cuda = mod.return_add_one(x_cuda) + # test support for nested container passing + y_cuda = mod.return_add_one({"x": [x_cuda]}) assert isinstance(y_cuda, torch.Tensor) assert y_cuda.shape == (5,) assert y_cuda.dtype == torch.float32 diff --git a/src/runtime/disco/protocol.h b/src/runtime/disco/protocol.h index e36935c8d27a..067a4f0d4a67 100644 --- a/src/runtime/disco/protocol.h +++ b/src/runtime/disco/protocol.h @@ -49,7 +49,7 @@ struct DiscoProtocol { /*! \brief Recycle all the memory used in the arena */ inline void RecycleAll() { - this->object_arena_.clear(); + this->any_arena_.clear(); this->arena_.RecycleAll(); } @@ -81,7 +81,7 @@ struct DiscoProtocol { } support::Arena arena_; - std::vector object_arena_; + std::vector any_arena_; friend struct RPCReference; }; @@ -213,7 +213,7 @@ inline void DiscoProtocol::ReadFFIAny(TVMFFIAny* out) { << Object::TypeIndex2Key(type_index) << " (type_index = " << type_index << ")"; } *reinterpret_cast(out) = result; - object_arena_.push_back(result); + any_arena_.push_back(result); } inline std::string DiscoDebugObject::SaveToStr() const { diff --git a/src/runtime/minrpc/rpc_reference.h b/src/runtime/minrpc/rpc_reference.h index ee08ad12c736..8b21b2492716 100644 --- a/src/runtime/minrpc/rpc_reference.h +++ b/src/runtime/minrpc/rpc_reference.h @@ -472,7 +472,9 @@ struct RPCReference { break; } default: { - if (type_index >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) { + if (type_index >= ffi::TypeIndex::kTVMFFIStaticObjectBegin || + type_index == ffi::TypeIndex::kTVMFFISmallStr || + type_index == ffi::TypeIndex::kTVMFFISmallBytes) { channel->ReadFFIAny(&(packed_args[i])); } else { channel->ThrowError(RPCServerStatus::kUnknownTypeIndex); diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc index c51484b2790f..0778b5539474 100644 --- a/src/runtime/rpc/rpc_endpoint.cc +++ b/src/runtime/rpc/rpc_endpoint.cc @@ -171,6 +171,12 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { for (int i = 0; i < args.size(); ++i) { if (args[i] == nullptr) continue; if (args[i].type_index() == ffi::TypeIndex::kTVMFFIModule) continue; + if (args[i].type_index() == ffi::TypeIndex::kTVMFFISmallStr || + args[i].type_index() == ffi::TypeIndex::kTVMFFISmallBytes) + continue; + if (args[i].type_index() == ffi::TypeIndex::kTVMFFIStr || + args[i].type_index() == ffi::TypeIndex::kTVMFFIBytes) + continue; if (const Object* obj = args[i].as()) { if (!obj->IsInstance()) { LOG(FATAL) << "ValueError: Cannot pass argument " << i << ", type " << obj->GetTypeKey() @@ -221,14 +227,20 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { void WriteFFIAny(const TVMFFIAny* in) { // NOTE: for now all remote object are encoded as RPCObjectRef // follow the same disco protocol in case we would like to upgrade later - // - // Rationale note: Only handle remote object allows the same mechanism to work for minRPC - // which is needed for wasm and other env that goes through C API + // TODO(tqchen): consider merge with disco protocol const AnyView* any_view_ptr = reinterpret_cast(in); if (const auto* ref = any_view_ptr->as()) { this->template Write(runtime::TypeIndex::kRuntimeRPCObjectRef); uint64_t handle = reinterpret_cast(ref->object_handle()); this->template Write(handle); + } else if (auto opt_str = any_view_ptr->as()) { + this->template Write(ffi::TypeIndex::kTVMFFIStr); + this->template Write((*opt_str).size()); + this->template WriteArray((*opt_str).data(), (*opt_str).size()); + } else if (auto opt_bytes = any_view_ptr->as()) { + this->template Write(ffi::TypeIndex::kTVMFFIBytes); + this->template Write((*opt_bytes).size()); + this->template WriteArray((*opt_bytes).data(), (*opt_bytes).size()); } else { LOG(FATAL) << "ValueError: Object type is not supported in RPC calling convention: " << any_view_ptr->GetTypeKey() << " (type_index = " << any_view_ptr->type_index() @@ -239,6 +251,10 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { const AnyView* any_view_ptr = reinterpret_cast(in); if (any_view_ptr->as()) { return sizeof(uint32_t) + sizeof(int64_t); + } else if (auto opt_str = any_view_ptr->as()) { + return sizeof(uint32_t) + sizeof(uint64_t) + (*opt_str).size(); + } else if (auto opt_bytes = any_view_ptr->as()) { + return sizeof(uint32_t) + sizeof(uint64_t) + (*opt_bytes).size(); } else { LOG(FATAL) << "ValueError: Object type is not supported in RPC calling convention: " << any_view_ptr->GetTypeKey() << " (type_index = " << any_view_ptr->type_index() @@ -266,7 +282,23 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { // Legacy ABI translation // TODO(tqchen): remove this once we have upgraded to new ABI *reinterpret_cast(out) = rpc_obj; - object_arena_.push_back(rpc_obj); + any_arena_.emplace_back(rpc_obj); + } else if (type_index == ffi::TypeIndex::kTVMFFIStr) { + uint64_t size; + this->template Read(&size); + std::string data(size, '\0'); + this->template ReadArray(data.data(), size); + ffi::String ret(std::move(data)); + *reinterpret_cast(out) = ret; + any_arena_.emplace_back(ret); + } else if (type_index == ffi::TypeIndex::kTVMFFIBytes) { + uint64_t size; + this->template Read(&size); + std::string data(size, '\0'); + this->template ReadArray(data.data(), size); + ffi::Bytes ret(std::move(data)); + *reinterpret_cast(out) = ret; + any_arena_.emplace_back(ret); } else { LOG(FATAL) << "ValueError: Object type is not supported in Disco calling convention: " << Object::TypeIndex2Key(type_index) << " (type_index = " << type_index << ")"; @@ -285,7 +317,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { /*! \brief Recycle all the memory used in the arena */ void RecycleAll() { - this->object_arena_.clear(); + this->any_arena_.clear(); this->arena_.RecycleAll(); } @@ -310,7 +342,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { // Internal arena support::Arena arena_; // internal arena for temp objects - std::vector object_arena_; + std::vector any_arena_; // State switcher void SwitchToState(State state) {