Skip to content

Commit 3c7c515

Browse files
authored
[FFI] Introduce FFI reflection support in python (#18065)
This PR brings up new reflection support in python. The new reflection now directly attaches property and methods to the class object themselves, making more efficient accessing than old mechanism. It will also support broader set of value types that are compatible with the FFI system. For now the old mechanism and new mechanism will co-exist, and we will phase out old mechanism as we migrate most needed features into new one.
1 parent 43e7676 commit 3c7c515

File tree

14 files changed

+519
-69
lines changed

14 files changed

+519
-69
lines changed

ffi/include/tvm/ffi/memory.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ class ObjAllocatorBase {
7070
* \param args The arguments.
7171
*/
7272
template <typename T, typename... Args>
73-
inline ObjectPtr<T> make_object(Args&&... args) {
73+
ObjectPtr<T> make_object(Args&&... args) {
7474
using Handler = typename Derived::template Handler<T>;
7575
static_assert(std::is_base_of<Object, T>::value, "make can only be used to create Object");
7676
T* ptr = Handler::New(static_cast<Derived*>(this), std::forward<Args>(args)...);
@@ -89,7 +89,7 @@ class ObjAllocatorBase {
8989
* \param args The arguments.
9090
*/
9191
template <typename ArrayType, typename ElemType, typename... Args>
92-
inline ObjectPtr<ArrayType> make_inplace_array(size_t num_elems, Args&&... args) {
92+
ObjectPtr<ArrayType> make_inplace_array(size_t num_elems, Args&&... args) {
9393
using Handler = typename Derived::template ArrayHandler<ArrayType, ElemType>;
9494
static_assert(std::is_base_of<Object, ArrayType>::value,
9595
"make_inplace_array can only be used to create Object");
@@ -109,7 +109,9 @@ class SimpleObjAllocator : public ObjAllocatorBase<SimpleObjAllocator> {
109109
template <typename T>
110110
class Handler {
111111
public:
112-
using StorageType = typename std::aligned_storage<sizeof(T), alignof(T)>::type;
112+
struct alignas(T) StorageType {
113+
char data[sizeof(T)];
114+
};
113115

114116
template <typename... Args>
115117
static T* New(SimpleObjAllocator*, Args&&... args) {

ffi/include/tvm/ffi/reflection/reflection.h

Lines changed: 98 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class DefaultValue : public FieldInfoTrait {
4646
public:
4747
explicit DefaultValue(Any value) : value_(value) {}
4848

49-
void Apply(TVMFFIFieldInfo* info) const {
49+
TVM_FFI_INLINE void Apply(TVMFFIFieldInfo* info) const {
5050
info->default_value = AnyView(value_).CopyToTVMFFIAny();
5151
info->flags |= kTVMFFIFieldFlagBitMaskHasDefault;
5252
}
@@ -65,16 +65,89 @@ class DefaultValue : public FieldInfoTrait {
6565
* \returns The byteoffset
6666
*/
6767
template <typename Class, typename T>
68-
inline int64_t GetFieldByteOffsetToObject(T Class::*field_ptr) {
68+
TVM_FFI_INLINE int64_t GetFieldByteOffsetToObject(T Class::*field_ptr) {
6969
int64_t field_offset_to_class =
7070
reinterpret_cast<int64_t>(&(static_cast<Class*>(nullptr)->*field_ptr));
7171
return field_offset_to_class - details::ObjectUnsafe::GetObjectOffsetToSubclass<Class>();
7272
}
7373

74+
class ReflectionDefBase {
75+
protected:
76+
template <typename T>
77+
static int FieldGetter(void* field, TVMFFIAny* result) {
78+
TVM_FFI_SAFE_CALL_BEGIN();
79+
*result = details::AnyUnsafe::MoveAnyToTVMFFIAny(Any(*reinterpret_cast<T*>(field)));
80+
TVM_FFI_SAFE_CALL_END();
81+
}
82+
83+
template <typename T>
84+
static int FieldSetter(void* field, const TVMFFIAny* value) {
85+
TVM_FFI_SAFE_CALL_BEGIN();
86+
*reinterpret_cast<T*>(field) = AnyView::CopyFromTVMFFIAny(*value).cast<T>();
87+
TVM_FFI_SAFE_CALL_END();
88+
}
89+
90+
template <typename T>
91+
static int ObjectCreatorDefault(TVMFFIObjectHandle* result) {
92+
TVM_FFI_SAFE_CALL_BEGIN();
93+
ObjectPtr<T> obj = make_object<T>();
94+
*result = details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(obj));
95+
TVM_FFI_SAFE_CALL_END();
96+
}
97+
98+
template <typename T>
99+
static TVM_FFI_INLINE void ApplyFieldInfoTrait(TVMFFIFieldInfo* info, const T& value) {
100+
if constexpr (std::is_base_of_v<FieldInfoTrait, std::decay_t<T>>) {
101+
value.Apply(info);
102+
}
103+
if constexpr (std::is_same_v<std::decay_t<T>, char*>) {
104+
info->doc = TVMFFIByteArray{value, std::char_traits<char>::length(value)};
105+
}
106+
}
107+
108+
template <typename T>
109+
static TVM_FFI_INLINE void ApplyMethodInfoTrait(TVMFFIMethodInfo* info, const T& value) {
110+
if constexpr (std::is_same_v<std::decay_t<T>, char*>) {
111+
info->doc = TVMFFIByteArray{value, std::char_traits<char>::length(value)};
112+
}
113+
}
114+
115+
template <typename T>
116+
static TVM_FFI_INLINE void ApplyExtraInfoTrait(TVMFFITypeExtraInfo* info, const T& value) {
117+
if constexpr (std::is_same_v<std::decay_t<T>, char*>) {
118+
info->doc = TVMFFIByteArray{value, std::char_traits<char>::length(value)};
119+
}
120+
}
121+
template <typename Class, typename R, typename... Args>
122+
static TVM_FFI_INLINE Function GetMethod(std::string name, R (Class::*func)(Args...)) {
123+
auto fwrap = [func](const Class* target, Args... params) -> R {
124+
return (const_cast<Class*>(target)->*func)(std::forward<Args>(params)...);
125+
};
126+
return ffi::Function::FromTyped(fwrap, name);
127+
}
128+
129+
template <typename Class, typename R, typename... Args>
130+
static TVM_FFI_INLINE Function GetMethod(std::string name, R (Class::*func)(Args...) const) {
131+
auto fwrap = [func](const Class* target, Args... params) -> R {
132+
return (target->*func)(std::forward<Args>(params)...);
133+
};
134+
return ffi::Function::FromTyped(fwrap, name);
135+
}
136+
137+
template <typename Class, typename Func>
138+
static TVM_FFI_INLINE Function GetMethod(std::string name, Func&& func) {
139+
return ffi::Function::FromTyped(std::forward<Func>(func), name);
140+
}
141+
};
142+
74143
template <typename Class>
75-
class ObjectDef {
144+
class ObjectDef : public ReflectionDefBase {
76145
public:
77-
ObjectDef() : type_index_(Class::_GetOrAllocRuntimeTypeIndex()), type_key_(Class::_type_key) {}
146+
template <typename... ExtraArgs>
147+
explicit ObjectDef(ExtraArgs&&... extra_args)
148+
: type_index_(Class::_GetOrAllocRuntimeTypeIndex()), type_key_(Class::_type_key) {
149+
RegisterExtraInfo(std::forward<ExtraArgs>(extra_args)...);
150+
}
78151

79152
/*!
80153
* \brief Define a readonly field.
@@ -90,7 +163,7 @@ class ObjectDef {
90163
* \return The reflection definition.
91164
*/
92165
template <typename T, typename... Extra>
93-
ObjectDef& def_ro(const char* name, T Class::*field_ptr, Extra&&... extra) {
166+
TVM_FFI_INLINE ObjectDef& def_ro(const char* name, T Class::*field_ptr, Extra&&... extra) {
94167
RegisterField(name, field_ptr, false, std::forward<Extra>(extra)...);
95168
return *this;
96169
}
@@ -109,7 +182,8 @@ class ObjectDef {
109182
* \return The reflection definition.
110183
*/
111184
template <typename T, typename... Extra>
112-
ObjectDef& def_rw(const char* name, T Class::*field_ptr, Extra&&... extra) {
185+
TVM_FFI_INLINE ObjectDef& def_rw(const char* name, T Class::*field_ptr, Extra&&... extra) {
186+
static_assert(Class::_type_mutable, "Only mutable classes are supported for writable fields");
113187
RegisterField(name, field_ptr, true, std::forward<Extra>(extra)...);
114188
return *this;
115189
}
@@ -127,7 +201,7 @@ class ObjectDef {
127201
* \return The reflection definition.
128202
*/
129203
template <typename Func, typename... Extra>
130-
ObjectDef& def(const char* name, Func&& func, Extra&&... extra) {
204+
TVM_FFI_INLINE ObjectDef& def(const char* name, Func&& func, Extra&&... extra) {
131205
RegisterMethod(name, false, std::forward<Func>(func), std::forward<Extra>(extra)...);
132206
return *this;
133207
}
@@ -145,12 +219,26 @@ class ObjectDef {
145219
* \return The reflection definition.
146220
*/
147221
template <typename Func, typename... Extra>
148-
ObjectDef& def_static(const char* name, Func&& func, Extra&&... extra) {
222+
TVM_FFI_INLINE ObjectDef& def_static(const char* name, Func&& func, Extra&&... extra) {
149223
RegisterMethod(name, true, std::forward<Func>(func), std::forward<Extra>(extra)...);
150224
return *this;
151225
}
152226

153227
private:
228+
template <typename... ExtraArgs>
229+
void RegisterExtraInfo(ExtraArgs&&... extra_args) {
230+
TVMFFITypeExtraInfo info;
231+
info.total_size = sizeof(Class);
232+
info.creator = nullptr;
233+
info.doc = TVMFFIByteArray{nullptr, 0};
234+
if constexpr (std::is_default_constructible_v<Class>) {
235+
info.creator = ObjectCreatorDefault<Class>;
236+
}
237+
// apply extra info traits
238+
((ApplyExtraInfoTrait(&info, std::forward<ExtraArgs>(extra_args)), ...));
239+
TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterExtraInfo(type_index_, &info));
240+
}
241+
154242
template <typename T, typename... ExtraArgs>
155243
void RegisterField(const char* name, T Class::*field_ptr, bool writable,
156244
ExtraArgs&&... extra_args) {
@@ -178,30 +266,6 @@ class ObjectDef {
178266
TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterField(type_index_, &info));
179267
}
180268

181-
template <typename T>
182-
static int FieldGetter(void* field, TVMFFIAny* result) {
183-
TVM_FFI_SAFE_CALL_BEGIN();
184-
*result = details::AnyUnsafe::MoveAnyToTVMFFIAny(Any(*reinterpret_cast<T*>(field)));
185-
TVM_FFI_SAFE_CALL_END();
186-
}
187-
188-
template <typename T>
189-
static int FieldSetter(void* field, const TVMFFIAny* value) {
190-
TVM_FFI_SAFE_CALL_BEGIN();
191-
*reinterpret_cast<T*>(field) = AnyView::CopyFromTVMFFIAny(*value).cast<T>();
192-
TVM_FFI_SAFE_CALL_END();
193-
}
194-
195-
template <typename T>
196-
static void ApplyFieldInfoTrait(TVMFFIFieldInfo* info, const T& value) {
197-
if constexpr (std::is_base_of_v<FieldInfoTrait, std::decay_t<T>>) {
198-
value.Apply(info);
199-
}
200-
if constexpr (std::is_same_v<std::decay_t<T>, char*>) {
201-
info->doc = TVMFFIByteArray{value, std::char_traits<char>::length(value)};
202-
}
203-
}
204-
205269
// register a method
206270
template <typename Func, typename... Extra>
207271
void RegisterMethod(const char* name, bool is_static, Func&& func, Extra&&... extra) {
@@ -214,41 +278,14 @@ class ObjectDef {
214278
info.flags |= kTVMFFIFieldFlagBitMaskIsStaticMethod;
215279
}
216280
// obtain the method function
217-
Function method = GetMethod(std::string(type_key_) + "." + name, std::forward<Func>(func));
281+
Function method =
282+
GetMethod<Class>(std::string(type_key_) + "." + name, std::forward<Func>(func));
218283
info.method = AnyView(method).CopyToTVMFFIAny();
219284
// apply method info traits
220285
((ApplyMethodInfoTrait(&info, std::forward<Extra>(extra)), ...));
221286
TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMethod(type_index_, &info));
222287
}
223288

224-
template <typename T>
225-
static void ApplyMethodInfoTrait(TVMFFIMethodInfo* info, const T& value) {
226-
if constexpr (std::is_same_v<std::decay_t<T>, char*>) {
227-
info->doc = TVMFFIByteArray{value, std::char_traits<char>::length(value)};
228-
}
229-
}
230-
231-
template <typename R, typename... Args>
232-
static Function GetMethod(std::string name, R (Class::*func)(Args...)) {
233-
auto fwrap = [func](const Class* target, Args... params) -> R {
234-
return (const_cast<Class*>(target)->*func)(std::forward<Args>(params)...);
235-
};
236-
return ffi::Function::FromTyped(fwrap, name);
237-
}
238-
239-
template <typename R, typename... Args>
240-
static Function GetMethod(std::string name, R (Class::*func)(Args...) const) {
241-
auto fwrap = [func](const Class* target, Args... params) -> R {
242-
return (target->*func)(std::forward<Args>(params)...);
243-
};
244-
return ffi::Function::FromTyped(fwrap, name);
245-
}
246-
247-
template <typename Func>
248-
static Function GetMethod(std::string name, Func&& func) {
249-
return ffi::Function::FromTyped(std::forward<Func>(func), name);
250-
}
251-
252289
int32_t type_index_;
253290
const char* type_key_;
254291
};

ffi/include/tvm/ffi/string.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,18 @@ class String : public ObjectRef {
306306
return Bytes::memncmp(data(), other, size(), std::strlen(other));
307307
}
308308

309+
/*!
310+
* \brief Compares this to other
311+
*
312+
* \param other The TVMFFIByteArray to compare with.
313+
*
314+
* \return zero if both char sequences compare equal. negative if this appear
315+
* before other, positive otherwise.
316+
*/
317+
int compare(const TVMFFIByteArray& other) const {
318+
return Bytes::memncmp(data(), other.data, size(), other.size);
319+
}
320+
309321
/*!
310322
* \brief Returns a pointer to the char array in the string.
311323
*

ffi/src/ffi/object.cc

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,83 @@ class TypeTable {
315315
Map<String, int64_t> type_key2index_;
316316
std::vector<Any> any_pool_;
317317
};
318+
319+
void MakeObjectFromPackedArgs(ffi::PackedArgs args, Any* ret) {
320+
String type_key = args[0].cast<String>();
321+
TVM_FFI_ICHECK(args.size() % 2 == 1);
322+
323+
int32_t type_index;
324+
TVMFFIByteArray type_key_array = TVMFFIByteArray{type_key.data(), type_key.size()};
325+
TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array, &type_index));
326+
const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(type_index);
327+
if (type_info == nullptr) {
328+
TVM_FFI_THROW(RuntimeError) << "Cannot find type `" << type_key << "`";
329+
}
330+
331+
if (type_info->extra_info == nullptr || type_info->extra_info->creator == nullptr) {
332+
TVM_FFI_THROW(RuntimeError) << "Type `" << type_key << "` does not support reflection creation";
333+
}
334+
TVMFFIObjectHandle handle;
335+
TVM_FFI_CHECK_SAFE_CALL(type_info->extra_info->creator(&handle));
336+
ObjectPtr<Object> ptr =
337+
details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(handle));
338+
339+
std::vector<String> keys;
340+
std::vector<bool> keys_found;
341+
342+
for (int i = 1; i < args.size(); i += 2) {
343+
keys.push_back(args[i].cast<String>());
344+
}
345+
keys_found.resize(keys.size(), false);
346+
347+
auto search_field = [&](const TVMFFIByteArray& field_name) {
348+
for (size_t i = 0; i < keys.size(); ++i) {
349+
if (keys_found[i]) continue;
350+
if (keys[i].compare(field_name) == 0) {
351+
return i;
352+
}
353+
}
354+
return keys.size();
355+
};
356+
357+
auto update_fields = [&](const TVMFFITypeInfo* tinfo) {
358+
for (int i = 0; i < tinfo->num_fields; ++i) {
359+
const TVMFFIFieldInfo* field_info = tinfo->fields + i;
360+
size_t arg_index = search_field(field_info->name);
361+
void* field_addr = reinterpret_cast<char*>(ptr.get()) + field_info->offset;
362+
if (arg_index < keys.size()) {
363+
AnyView field_value = args[arg_index * 2 + 2];
364+
field_info->setter(field_addr, reinterpret_cast<const TVMFFIAny*>(&field_value));
365+
keys_found[arg_index] = true;
366+
} else if (field_info->flags & kTVMFFIFieldFlagBitMaskHasDefault) {
367+
field_info->setter(field_addr, &(field_info->default_value));
368+
} else {
369+
TVM_FFI_THROW(TypeError) << "Required field `"
370+
<< String(field_info->name.data, field_info->name.size)
371+
<< "` not set in type `" << type_key << "`";
372+
}
373+
}
374+
};
375+
376+
// iterate through acenstors in parent to child order
377+
// skip the first one since it is always the root object
378+
TVM_FFI_ICHECK(type_info->type_acenstors[0] == TypeIndex::kTVMFFIObject);
379+
for (int i = 1; i < type_info->type_depth; ++i) {
380+
update_fields(TVMFFIGetTypeInfo(type_info->type_acenstors[i]));
381+
}
382+
update_fields(type_info);
383+
384+
for (size_t i = 0; i < keys.size(); ++i) {
385+
if (!keys_found[i]) {
386+
TVM_FFI_THROW(TypeError) << "Type `" << type_key << "` does not have field `" << keys[i]
387+
<< "`";
388+
}
389+
}
390+
*ret = ObjectRef(ptr);
391+
}
392+
393+
TVM_FFI_REGISTER_GLOBAL("ffi.MakeObjectFromPackedArgs").set_body_packed(MakeObjectFromPackedArgs);
394+
318395
} // namespace ffi
319396
} // namespace tvm
320397

0 commit comments

Comments
 (0)