Skip to content

Commit 46eac56

Browse files
authored
[FFI][ABI] Introduce weak rc support (#18259)
This PR adds weak ref counter support to the FFI ABI. Weak rc is useful when we want to break cyclic dependencies. - When a strong rc goes to zero, we call the destructor of the object, but not freeing the memory - When both strong and weak rc goes to zero, we call the memory free operation The weak rc mechanism is useful when we want to break cyclic dependencies in object, where the weak rc can keep memory alive but the destructor is called. As of now, because we deliberately avoid cyles in codebase, we do not have strong use-case for weak rc. However, given weak rc is common practice in shared_ptr, Rust RC, and also used in torch's c10::intrusive_ptr. It is better to make sure the ABI is future compatible to such use-cases before we freeze. This PR implements weak rc as a u32 counter and strong rc as a u64 counter, with the following design consideration. - Weak rc is very rarely used and u32 is sufficient. - Keeping weak rc in u32 allows us to keep object header size to 24 bytes, saving extra 8 bytes(considering alignment) We also need to update deleter to take flags that consider both weak and strong deletion events. The implementation tries to optimize common case where both strong and weak goes to 0 at the same time and call deleter once with both flags set.
1 parent b67650f commit 46eac56

File tree

16 files changed

+475
-58
lines changed

16 files changed

+475
-58
lines changed

ffi/include/tvm/ffi/c_api.h

Lines changed: 56 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,36 @@ typedef enum {
156156
/*! \brief Handle to Object from C API's pov */
157157
typedef void* TVMFFIObjectHandle;
158158

159+
/*!
160+
* \brief bitmask of the object deleter flag.
161+
*/
162+
#ifdef __cplusplus
163+
enum TVMFFIObjectDeleterFlagBitMask : int32_t {
164+
#else
165+
typedef enum {
166+
#endif
167+
/*!
168+
* \brief deleter action when strong reference count becomes zero.
169+
* Need to call destructor of the object but not free the memory block.
170+
*/
171+
kTVMFFIObjectDeleterFlagBitMaskStrong = 1 << 0,
172+
/*!
173+
* \brief deleter action when weak reference count becomes zero.
174+
* Need to free the memory block.
175+
*/
176+
kTVMFFIObjectDeleterFlagBitMaskWeak = 1 << 1,
177+
/*!
178+
* \brief deleter action when both strong and weak reference counts become zero.
179+
* \note This is the most common case.
180+
*/
181+
kTVMFFIObjectDeleterFlagBitMaskBoth =
182+
(kTVMFFIObjectDeleterFlagBitMaskStrong | kTVMFFIObjectDeleterFlagBitMaskWeak),
183+
#ifdef __cplusplus
184+
};
185+
#else
186+
} TVMFFIObjectDeleterFlagBitMask;
187+
#endif
188+
159189
/*!
160190
* \brief C-based type of all FFI object header that allocates on heap.
161191
* \note TVMFFIObject and TVMFFIAny share the common type_index header
@@ -166,11 +196,22 @@ typedef struct TVMFFIObject {
166196
* \note The type index of Object and Any are shared in FFI.
167197
*/
168198
int32_t type_index;
169-
/*! \brief Reference counter of the object. */
170-
int32_t ref_counter;
199+
/*!
200+
* \brief Weak reference counter of the object, for compatiblity with weak_ptr design.
201+
* \note Use u32 to ensure that overall object stays within 24-byte boundary, usually
202+
* manipulation of weak counter is less common than strong counter.
203+
*/
204+
uint32_t weak_ref_count;
205+
/*! \brief Strong reference counter of the object. */
206+
uint64_t strong_ref_count;
171207
union {
172-
/*! \brief Deleter to be invoked when reference counter goes to zero. */
173-
void (*deleter)(struct TVMFFIObject* self);
208+
/*!
209+
* \brief Deleter to be invoked when strong reference counter goes to zero.
210+
* \param self The self object handle.
211+
* \param flags The flags to indicate deletion behavior.
212+
* \sa TVMFFIObjectDeleterFlagBitMask
213+
*/
214+
void (*deleter)(struct TVMFFIObject* self, int flags);
174215
/*!
175216
* \brief auxilary field to TVMFFIObject is always 8 bytes aligned.
176217
* \note This helps us to ensure cross platform compatibility.
@@ -307,13 +348,19 @@ typedef struct {
307348
// Section: Basic object API
308349
//------------------------------------------------------------
309350
/*!
310-
* \brief Free an object handle by decreasing reference
351+
* \brief Increas the strong reference count of an object handle
352+
* \param obj The object handle.
353+
* \note Internally we increase the reference counter of the object.
354+
* \return 0 when success, nonzero when failure happens
355+
*/
356+
TVM_FFI_DLL int TVMFFIObjectIncRef(TVMFFIObjectHandle obj);
357+
358+
/*!
359+
* \brief Free an object handle by decreasing strong reference
311360
* \param obj The object handle.
312-
* \note Internally we decrease the reference counter of the object.
313-
* The object will be freed when every reference to the object are removed.
314361
* \return 0 when success, nonzero when failure happens
315362
*/
316-
TVM_FFI_DLL int TVMFFIObjectFree(TVMFFIObjectHandle obj);
363+
TVM_FFI_DLL int TVMFFIObjectDecRef(TVMFFIObjectHandle obj);
317364

318365
/*!
319366
* \brief Convert type key to type index.
@@ -470,7 +517,7 @@ TVM_FFI_DLL int TVMFFIDataTypeFromString(const TVMFFIByteArray* str, DLDataType*
470517
* \param dtype The DLDataType to convert.
471518
* \param out The output string.
472519
* \return 0 when success, nonzero when failure happens
473-
* \note out is a String object that needs to be freed by the caller via TVMFFIObjectFree.
520+
* \note out is a String object that needs to be freed by the caller via TVMFFIObjectDecRef.
474521
The content of string can be accessed via TVMFFIObjectGetByteArrayPtr.
475522
476523
* \note The input dtype is a pointer to the DLDataType to avoid ABI compatibility issues.

ffi/include/tvm/ffi/memory.h

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ namespace tvm {
3333
namespace ffi {
3434

3535
/*! \brief Deleter function for obeject */
36-
typedef void (*FObjectDeleter)(TVMFFIObject* obj);
36+
typedef void (*FObjectDeleter)(TVMFFIObject* obj, int flags);
3737

3838
/*!
3939
* \brief Allocate an object using default allocator.
@@ -75,7 +75,8 @@ class ObjAllocatorBase {
7575
static_assert(std::is_base_of<Object, T>::value, "make can only be used to create Object");
7676
T* ptr = Handler::New(static_cast<Derived*>(this), std::forward<Args>(args)...);
7777
TVMFFIObject* ffi_ptr = details::ObjectUnsafe::GetHeader(ptr);
78-
ffi_ptr->ref_counter = 1;
78+
ffi_ptr->strong_ref_count = 1;
79+
ffi_ptr->weak_ref_count = 1;
7980
ffi_ptr->type_index = T::RuntimeTypeIndex();
8081
ffi_ptr->deleter = Handler::Deleter();
8182
return details::ObjectUnsafe::ObjectPtrFromOwned<T>(ptr);
@@ -96,7 +97,8 @@ class ObjAllocatorBase {
9697
ArrayType* ptr =
9798
Handler::New(static_cast<Derived*>(this), num_elems, std::forward<Args>(args)...);
9899
TVMFFIObject* ffi_ptr = details::ObjectUnsafe::GetHeader(ptr);
99-
ffi_ptr->ref_counter = 1;
100+
ffi_ptr->strong_ref_count = 1;
101+
ffi_ptr->weak_ref_count = 1;
100102
ffi_ptr->type_index = ArrayType::RuntimeTypeIndex();
101103
ffi_ptr->deleter = Handler::Deleter();
102104
return details::ObjectUnsafe::ObjectPtrFromOwned<ArrayType>(ptr);
@@ -136,14 +138,18 @@ class SimpleObjAllocator : public ObjAllocatorBase<SimpleObjAllocator> {
136138
static FObjectDeleter Deleter() { return Deleter_; }
137139

138140
private:
139-
static void Deleter_(TVMFFIObject* objptr) {
141+
static void Deleter_(TVMFFIObject* objptr, int flags) {
140142
T* tptr = details::ObjectUnsafe::RawObjectPtrFromUnowned<T>(objptr);
141-
// It is important to do tptr->T::~T(),
142-
// so that we explicitly call the specific destructor
143-
// instead of tptr->~T(), which could mean the intention
144-
// call a virtual destructor(which may not be available and is not required).
145-
tptr->T::~T();
146-
delete reinterpret_cast<StorageType*>(tptr);
143+
if (flags & kTVMFFIObjectDeleterFlagBitMaskStrong) {
144+
// It is important to do tptr->T::~T(),
145+
// so that we explicitly call the specific destructor
146+
// instead of tptr->~T(), which could mean the intention
147+
// call a virtual destructor(which may not be available and is not required).
148+
tptr->T::~T();
149+
}
150+
if (flags & kTVMFFIObjectDeleterFlagBitMaskWeak) {
151+
delete reinterpret_cast<StorageType*>(tptr);
152+
}
147153
}
148154
};
149155

@@ -182,15 +188,19 @@ class SimpleObjAllocator : public ObjAllocatorBase<SimpleObjAllocator> {
182188
static FObjectDeleter Deleter() { return Deleter_; }
183189

184190
private:
185-
static void Deleter_(TVMFFIObject* objptr) {
191+
static void Deleter_(TVMFFIObject* objptr, int flags) {
186192
ArrayType* tptr = details::ObjectUnsafe::RawObjectPtrFromUnowned<ArrayType>(objptr);
187-
// It is important to do tptr->ArrayType::~ArrayType(),
188-
// so that we explicitly call the specific destructor
189-
// instead of tptr->~ArrayType(), which could mean the intention
190-
// call a virtual destructor(which may not be available and is not required).
191-
tptr->ArrayType::~ArrayType();
192-
StorageType* p = reinterpret_cast<StorageType*>(tptr);
193-
delete[] p;
193+
if (flags & kTVMFFIObjectDeleterFlagBitMaskStrong) {
194+
// It is important to do tptr->ArrayType::~ArrayType(),
195+
// so that we explicitly call the specific destructor
196+
// instead of tptr->~ArrayType(), which could mean the intention
197+
// call a virtual destructor(which may not be available and is not required).
198+
tptr->ArrayType::~ArrayType();
199+
}
200+
if (flags & kTVMFFIObjectDeleterFlagBitMaskWeak) {
201+
StorageType* p = reinterpret_cast<StorageType*>(tptr);
202+
delete[] p;
203+
}
194204
}
195205
};
196206
};

0 commit comments

Comments
 (0)