Skip to content

Commit ea015d8

Browse files
authored
[FFI] Update typeinfo to speedup parent reflection (apache#18083)
This PR updates the typeinfo to speedup parent reflection Also optimizes a few if constexpr branches to explicitly place else to eliminate branch early in compilation.
1 parent b896ad8 commit ea015d8

File tree

9 files changed

+109
-39
lines changed

9 files changed

+109
-39
lines changed

include/tvm/ffi/c_api.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,7 @@ typedef struct {
437437
/*!
438438
* \brief Runtime type information for object type checking.
439439
*/
440-
typedef struct {
440+
typedef struct TVMFFITypeInfo {
441441
/*!
442442
*\brief The runtime type index,
443443
* It can be allocated during runtime if the type is dynamic.
@@ -452,7 +452,7 @@ typedef struct {
452452
* \note To keep things simple, we do not allow multiple inheritance so the
453453
* hieracy stays as a tree
454454
*/
455-
const int32_t* type_acenstors;
455+
const struct TVMFFITypeInfo** type_acenstors;
456456
// The following fields are used for reflection
457457
/*! \brief Cached hash value of the type key, used for consistent structural hashing. */
458458
uint64_t type_key_hash;

include/tvm/ffi/container/tuple.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,8 +253,9 @@ struct TypeTraits<Tuple<Types...>> : public ObjectRefTypeTraitsBase<Tuple<Types.
253253
}
254254
if constexpr (sizeof...(Rest) > 0) {
255255
return TryConvertElements<I + 1, Rest...>(std::move(arr));
256+
} else {
257+
return true;
256258
}
257-
return true;
258259
}
259260

260261
static TVM_FFI_INLINE std::string TypeStr() {

include/tvm/ffi/object.h

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -693,34 +693,38 @@ template <typename TargetType>
693693
TVM_FFI_INLINE bool IsObjectInstance(int32_t object_type_index) {
694694
static_assert(std::is_base_of_v<Object, TargetType>);
695695
// Everything is a subclass of object.
696-
if constexpr (std::is_same<TargetType, Object>::value) return true;
697-
698-
if constexpr (TargetType::_type_final) {
696+
if constexpr (std::is_same<TargetType, Object>::value) {
697+
return true;
698+
} else if constexpr (TargetType::_type_final) {
699699
// if the target type is a final type
700700
// then we only need to check the equivalence.
701701
return object_type_index == TargetType::RuntimeTypeIndex();
702-
}
703-
704-
// if target type is a non-leaf type
705-
// Check if type index falls into the range of reserved slots.
706-
int32_t target_type_index = TargetType::RuntimeTypeIndex();
707-
int32_t begin = target_type_index;
708-
// The condition will be optimized by constant-folding.
709-
if constexpr (TargetType::_type_child_slots != 0) {
710-
// total_slots = child_slots + 1 (including self)
711-
int32_t end = begin + TargetType::_type_child_slots + 1;
712-
if (object_type_index >= begin && object_type_index < end) return true;
713702
} else {
714-
if (object_type_index == begin) return true;
715-
}
716-
if (!TargetType::_type_child_slots_can_overflow) return false;
717-
// Invariance: parent index is always smaller than the child.
718-
if (object_type_index < target_type_index) return false;
719-
// Do a runtime lookup of type information
720-
// the function checks that the info exists
721-
const TypeInfo* type_info = TVMFFIGetTypeInfo(object_type_index);
722-
return (type_info->type_depth > TargetType::_type_depth &&
723-
type_info->type_acenstors[TargetType::_type_depth] == target_type_index);
703+
// Explicitly enclose in else to eliminate this branch early in compilation.
704+
// if target type is a non-leaf type
705+
// Check if type index falls into the range of reserved slots.
706+
int32_t target_type_index = TargetType::RuntimeTypeIndex();
707+
int32_t begin = target_type_index;
708+
// The condition will be optimized by constant-folding.
709+
if constexpr (TargetType::_type_child_slots != 0) {
710+
// total_slots = child_slots + 1 (including self)
711+
int32_t end = begin + TargetType::_type_child_slots + 1;
712+
if (object_type_index >= begin && object_type_index < end) return true;
713+
} else {
714+
if (object_type_index == begin) return true;
715+
}
716+
if constexpr (TargetType::_type_child_slots_can_overflow) {
717+
// Invariance: parent index is always smaller than the child.
718+
if (object_type_index < target_type_index) return false;
719+
// Do a runtime lookup of type information
720+
// the function checks that the info exists
721+
const TypeInfo* type_info = TVMFFIGetTypeInfo(object_type_index);
722+
return (type_info->type_depth > TargetType::_type_depth &&
723+
type_info->type_acenstors[TargetType::_type_depth]->type_index == target_type_index);
724+
} else {
725+
return false;
726+
}
727+
}
724728
}
725729

726730
/*!

include/tvm/ffi/reflection/reflection.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,31 @@ inline Function GetMethod(std::string_view type_key, const char* method_name) {
392392
return AnyView::CopyFromTVMFFIAny(info->method).cast<Function>();
393393
}
394394

395+
/*!
396+
* \brief Visit each field info of the type info and run callback.
397+
*
398+
* \tparam Callback The callback function type.
399+
*
400+
* \param type_info The type info.
401+
* \param callback The callback function.
402+
*
403+
* \note This function calls both the child and parent type info.
404+
*/
405+
template <typename Callback>
406+
inline void ForEachFieldInfo(const TypeInfo* type_info, Callback callback) {
407+
// iterate through acenstors in parent to child order
408+
// skip the first one since it is always the root object
409+
for (int i = 1; i < type_info->type_depth; ++i) {
410+
const TVMFFITypeInfo* parent_info = type_info->type_acenstors[i];
411+
for (int j = 0; j < parent_info->num_fields; ++j) {
412+
callback(parent_info->fields + j);
413+
}
414+
}
415+
for (int i = 0; i < type_info->num_fields; ++i) {
416+
callback(type_info->fields + i);
417+
}
418+
}
419+
395420
} // namespace reflection
396421
} // namespace ffi
397422
} // namespace tvm

include/tvm/ffi/string.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,15 @@ class String : public ObjectRef {
255255
*/
256256
String(std::string&& other) // NOLINT(*)
257257
: ObjectRef(make_object<details::BytesObjStdImpl<StringObj>>(std::move(other))) {}
258+
259+
/*!
260+
* \brief constructor from TVMFFIByteArray
261+
*
262+
* \param other a TVMFFIByteArray.
263+
*/
264+
explicit String(TVMFFIByteArray other)
265+
: ObjectRef(details::MakeInplaceBytes<StringObj>(other.data, other.size)) {}
266+
258267
/*!
259268
* \brief Swap this String with another string
260269
* \param other The other string

src/ffi/object.cc

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class TypeTable {
5252
/*! \brief stored type key */
5353
String type_key_data;
5454
/*! \brief acenstor information */
55-
std::vector<int32_t> type_acenstors_data;
55+
std::vector<const TVMFFITypeInfo*> type_acenstors_data;
5656
/*! \brief type fields informaton */
5757
std::vector<TVMFFIFieldInfo> type_fields_data;
5858
/*! \brief type methods informaton */
@@ -85,7 +85,7 @@ class TypeTable {
8585
type_acenstors_data[i] = parent->type_acenstors[i];
8686
}
8787
// set last type information to be parent
88-
type_acenstors_data[parent->type_depth] = parent->type_index;
88+
type_acenstors_data[parent->type_depth] = parent;
8989
}
9090
// initialize type info: no change to type_key and type_acenstors fields
9191
// after this line
@@ -234,7 +234,7 @@ class TypeTable {
234234
for (auto it = type_table_.rbegin(); it != type_table_.rend(); ++it) {
235235
const Entry* ptr = it->get();
236236
if (ptr != nullptr && ptr->type_depth != 0) {
237-
int parent_index = ptr->type_acenstors[ptr->type_depth - 1];
237+
int parent_index = ptr->type_acenstors[ptr->type_depth - 1]->type_index;
238238
num_children[parent_index] += num_children[ptr->type_index] + 1;
239239
if (expected_child_slots[ptr->type_index] + 1 < ptr->num_slots) {
240240
expected_child_slots[ptr->type_index] = ptr->num_slots - 1;
@@ -247,7 +247,7 @@ class TypeTable {
247247
if (ptr != nullptr && num_children[ptr->type_index] >= min_children_count) {
248248
std::cerr << '[' << ptr->type_index << "]\t" << ToStringView(ptr->type_key);
249249
if (ptr->type_depth != 0) {
250-
int32_t parent_index = ptr->type_acenstors[ptr->type_depth - 1];
250+
int32_t parent_index = ptr->type_acenstors[ptr->type_depth - 1]->type_index;
251251
std::cerr << "\tparent=" << ToStringView(type_table_[parent_index]->type_key);
252252
} else {
253253
std::cerr << "\tparent=root";
@@ -375,9 +375,8 @@ void MakeObjectFromPackedArgs(ffi::PackedArgs args, Any* ret) {
375375

376376
// iterate through acenstors in parent to child order
377377
// skip the first one since it is always the root object
378-
TVM_FFI_ICHECK(type_info->type_acenstors[0] == TypeIndex::kTVMFFIObject);
379378
for (int i = 1; i < type_info->type_depth; ++i) {
380-
update_fields(TVMFFIGetTypeInfo(type_info->type_acenstors[i]));
379+
update_fields(type_info->type_acenstors[i]);
381380
}
382381
update_fields(type_info);
383382

tests/cpp/test_object.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ TEST(Object, TypeInfo) {
5555
EXPECT_TRUE(info != nullptr);
5656
EXPECT_EQ(info->type_index, TIntObj::RuntimeTypeIndex());
5757
EXPECT_EQ(info->type_depth, 2);
58-
EXPECT_EQ(info->type_acenstors[0], Object::_type_index);
59-
EXPECT_EQ(info->type_acenstors[1], TNumberObj::_type_index);
58+
EXPECT_EQ(info->type_acenstors[0]->type_index, Object::_type_index);
59+
EXPECT_EQ(info->type_acenstors[1]->type_index, TNumberObj::_type_index);
6060
EXPECT_GE(info->type_index, TypeIndex::kTVMFFIDynObjectBegin);
6161
}
6262

tests/cpp/test_reflection.cc

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
* under the License.
1919
*/
2020
#include <gtest/gtest.h>
21+
#include <tvm/ffi/container/map.h>
2122
#include <tvm/ffi/object.h>
2223
#include <tvm/ffi/reflection/reflection.h>
2324
#include <tvm/ffi/string.h>
@@ -29,11 +30,20 @@ namespace {
2930
using namespace tvm::ffi;
3031
using namespace tvm::ffi::testing;
3132

32-
struct A : public Object {
33+
struct TestObjA : public Object {
3334
int64_t x;
3435
int64_t y;
3536

37+
static constexpr const char* _type_key = "test.TestObjA";
3638
static constexpr bool _type_mutable = true;
39+
TVM_FFI_DECLARE_BASE_OBJECT_INFO(TestObjA, Object);
40+
};
41+
42+
struct TestObjADerived : public TestObjA {
43+
int64_t z;
44+
45+
static constexpr const char* _type_key = "test.TestObjADerived";
46+
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TestObjADerived, TestObjA);
3747
};
3848

3949
TVM_FFI_STATIC_INIT_BLOCK({
@@ -56,12 +66,13 @@ TVM_FFI_STATIC_INIT_BLOCK({
5666
return self->value - other;
5767
});
5868

59-
refl::ObjectDef<A>().def_ro("x", &A::x).def_rw("y", &A::y);
69+
refl::ObjectDef<TestObjA>().def_ro("x", &TestObjA::x).def_rw("y", &TestObjA::y);
70+
refl::ObjectDef<TestObjADerived>().def_ro("z", &TestObjADerived::z);
6071
});
6172

6273
TEST(Reflection, GetFieldByteOffset) {
63-
EXPECT_EQ(reflection::GetFieldByteOffsetToObject(&A::x), sizeof(TVMFFIObject));
64-
EXPECT_EQ(reflection::GetFieldByteOffsetToObject(&A::y), 8 + sizeof(TVMFFIObject));
74+
EXPECT_EQ(reflection::GetFieldByteOffsetToObject(&TestObjA::x), sizeof(TVMFFIObject));
75+
EXPECT_EQ(reflection::GetFieldByteOffsetToObject(&TestObjA::y), 8 + sizeof(TVMFFIObject));
6576
EXPECT_EQ(reflection::GetFieldByteOffsetToObject(&TIntObj::value), sizeof(TVMFFIObject));
6677
}
6778

@@ -131,4 +142,15 @@ TEST(Reflection, CallMethod) {
131142
EXPECT_EQ(prim_expr_sub(TPrimExpr("float", 1), 2.0).cast<double>(), -1.0);
132143
}
133144

145+
TEST(Reflection, ForEachFieldInfo) {
146+
const TypeInfo* info = TVMFFIGetTypeInfo(TestObjADerived::RuntimeTypeIndex());
147+
Map<String, int> field_name_to_offset;
148+
reflection::ForEachFieldInfo(info, [&](const TVMFFIFieldInfo* field_info) {
149+
field_name_to_offset.Set(String(field_info->name), field_info->offset);
150+
});
151+
EXPECT_EQ(field_name_to_offset["x"], sizeof(TVMFFIObject));
152+
EXPECT_EQ(field_name_to_offset["y"], 8 + sizeof(TVMFFIObject));
153+
EXPECT_EQ(field_name_to_offset["z"], 16 + sizeof(TVMFFIObject));
154+
}
155+
134156
} // namespace

tests/cpp/testing_object.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
#include <tvm/ffi/memory.h>
2424
#include <tvm/ffi/object.h>
25+
#include <tvm/ffi/reflection/reflection.h>
2526
#include <tvm/ffi/string.h>
2627

2728
namespace tvm {
@@ -81,6 +82,15 @@ class TFloatObj : public TNumberObj {
8182

8283
double Add(double other) const { return value + other; }
8384

85+
static void RegisterReflection() {
86+
namespace refl = tvm::ffi::reflection;
87+
refl::ObjectDef<TFloatObj>()
88+
.def_ro("value", &TFloatObj::value, "float value field", refl::DefaultValue(10.0))
89+
.def("sub",
90+
[](const TFloatObj* self, double other) -> double { return self->value - other; })
91+
.def("add", &TFloatObj::Add, "add method");
92+
}
93+
8494
static constexpr const char* _type_key = "test.Float";
8595
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TFloatObj, TNumberObj);
8696
};

0 commit comments

Comments
 (0)