Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ffi/include/tvm/ffi/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ typedef struct {
/*!
* \brief Runtime type information for object type checking.
*/
typedef struct {
typedef struct TVMFFITypeInfo {
/*!
*\brief The runtime type index,
* It can be allocated during runtime if the type is dynamic.
Expand All @@ -452,7 +452,7 @@ typedef struct {
* \note To keep things simple, we do not allow multiple inheritance so the
* hieracy stays as a tree
*/
const int32_t* type_acenstors;
const struct TVMFFITypeInfo** type_acenstors;
// The following fields are used for reflection
/*! \brief Cached hash value of the type key, used for consistent structural hashing. */
uint64_t type_key_hash;
Expand Down
3 changes: 2 additions & 1 deletion ffi/include/tvm/ffi/container/tuple.h
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,9 @@ struct TypeTraits<Tuple<Types...>> : public ObjectRefTypeTraitsBase<Tuple<Types.
}
if constexpr (sizeof...(Rest) > 0) {
return TryConvertElements<I + 1, Rest...>(std::move(arr));
} else {
return true;
}
return true;
}

static TVM_FFI_INLINE std::string TypeStr() {
Expand Down
52 changes: 28 additions & 24 deletions ffi/include/tvm/ffi/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -693,34 +693,38 @@ template <typename TargetType>
TVM_FFI_INLINE bool IsObjectInstance(int32_t object_type_index) {
static_assert(std::is_base_of_v<Object, TargetType>);
// Everything is a subclass of object.
if constexpr (std::is_same<TargetType, Object>::value) return true;

if constexpr (TargetType::_type_final) {
if constexpr (std::is_same<TargetType, Object>::value) {
return true;
} else if constexpr (TargetType::_type_final) {
// if the target type is a final type
// then we only need to check the equivalence.
return object_type_index == TargetType::RuntimeTypeIndex();
}

// if target type is a non-leaf type
// Check if type index falls into the range of reserved slots.
int32_t target_type_index = TargetType::RuntimeTypeIndex();
int32_t begin = target_type_index;
// The condition will be optimized by constant-folding.
if constexpr (TargetType::_type_child_slots != 0) {
// total_slots = child_slots + 1 (including self)
int32_t end = begin + TargetType::_type_child_slots + 1;
if (object_type_index >= begin && object_type_index < end) return true;
} else {
if (object_type_index == begin) return true;
}
if (!TargetType::_type_child_slots_can_overflow) return false;
// Invariance: parent index is always smaller than the child.
if (object_type_index < target_type_index) return false;
// Do a runtime lookup of type information
// the function checks that the info exists
const TypeInfo* type_info = TVMFFIGetTypeInfo(object_type_index);
return (type_info->type_depth > TargetType::_type_depth &&
type_info->type_acenstors[TargetType::_type_depth] == target_type_index);
// Explicitly enclose in else to eliminate this branch early in compilation.
// if target type is a non-leaf type
// Check if type index falls into the range of reserved slots.
int32_t target_type_index = TargetType::RuntimeTypeIndex();
int32_t begin = target_type_index;
// The condition will be optimized by constant-folding.
if constexpr (TargetType::_type_child_slots != 0) {
// total_slots = child_slots + 1 (including self)
int32_t end = begin + TargetType::_type_child_slots + 1;
if (object_type_index >= begin && object_type_index < end) return true;
} else {
if (object_type_index == begin) return true;
}
if constexpr (TargetType::_type_child_slots_can_overflow) {
// Invariance: parent index is always smaller than the child.
if (object_type_index < target_type_index) return false;
// Do a runtime lookup of type information
// the function checks that the info exists
const TypeInfo* type_info = TVMFFIGetTypeInfo(object_type_index);
return (type_info->type_depth > TargetType::_type_depth &&
type_info->type_acenstors[TargetType::_type_depth]->type_index == target_type_index);
} else {
return false;
}
}
}

/*!
Expand Down
25 changes: 25 additions & 0 deletions ffi/include/tvm/ffi/reflection/reflection.h
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,31 @@ inline Function GetMethod(std::string_view type_key, const char* method_name) {
return AnyView::CopyFromTVMFFIAny(info->method).cast<Function>();
}

/*!
* \brief Visit each field info of the type info and run callback.
*
* \tparam Callback The callback function type.
*
* \param type_info The type info.
* \param callback The callback function.
*
* \note This function calls both the child and parent type info.
*/
template <typename Callback>
inline void ForEachFieldInfo(const TypeInfo* type_info, Callback callback) {
// iterate through acenstors in parent to child order
// skip the first one since it is always the root object
for (int i = 1; i < type_info->type_depth; ++i) {
const TVMFFITypeInfo* parent_info = type_info->type_acenstors[i];
for (int j = 0; j < parent_info->num_fields; ++j) {
callback(parent_info->fields + j);
}
}
for (int i = 0; i < type_info->num_fields; ++i) {
callback(type_info->fields + i);
}
}

} // namespace reflection
} // namespace ffi
} // namespace tvm
Expand Down
9 changes: 9 additions & 0 deletions ffi/include/tvm/ffi/string.h
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,15 @@ class String : public ObjectRef {
*/
String(std::string&& other) // NOLINT(*)
: ObjectRef(make_object<details::BytesObjStdImpl<StringObj>>(std::move(other))) {}

/*!
* \brief constructor from TVMFFIByteArray
*
* \param other a TVMFFIByteArray.
*/
explicit String(TVMFFIByteArray other)
: ObjectRef(details::MakeInplaceBytes<StringObj>(other.data, other.size)) {}

/*!
* \brief Swap this String with another string
* \param other The other string
Expand Down
11 changes: 5 additions & 6 deletions ffi/src/ffi/object.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class TypeTable {
/*! \brief stored type key */
String type_key_data;
/*! \brief acenstor information */
std::vector<int32_t> type_acenstors_data;
std::vector<const TVMFFITypeInfo*> type_acenstors_data;
/*! \brief type fields informaton */
std::vector<TVMFFIFieldInfo> type_fields_data;
/*! \brief type methods informaton */
Expand Down Expand Up @@ -85,7 +85,7 @@ class TypeTable {
type_acenstors_data[i] = parent->type_acenstors[i];
}
// set last type information to be parent
type_acenstors_data[parent->type_depth] = parent->type_index;
type_acenstors_data[parent->type_depth] = parent;
}
// initialize type info: no change to type_key and type_acenstors fields
// after this line
Expand Down Expand Up @@ -234,7 +234,7 @@ class TypeTable {
for (auto it = type_table_.rbegin(); it != type_table_.rend(); ++it) {
const Entry* ptr = it->get();
if (ptr != nullptr && ptr->type_depth != 0) {
int parent_index = ptr->type_acenstors[ptr->type_depth - 1];
int parent_index = ptr->type_acenstors[ptr->type_depth - 1]->type_index;
num_children[parent_index] += num_children[ptr->type_index] + 1;
if (expected_child_slots[ptr->type_index] + 1 < ptr->num_slots) {
expected_child_slots[ptr->type_index] = ptr->num_slots - 1;
Expand All @@ -247,7 +247,7 @@ class TypeTable {
if (ptr != nullptr && num_children[ptr->type_index] >= min_children_count) {
std::cerr << '[' << ptr->type_index << "]\t" << ToStringView(ptr->type_key);
if (ptr->type_depth != 0) {
int32_t parent_index = ptr->type_acenstors[ptr->type_depth - 1];
int32_t parent_index = ptr->type_acenstors[ptr->type_depth - 1]->type_index;
std::cerr << "\tparent=" << ToStringView(type_table_[parent_index]->type_key);
} else {
std::cerr << "\tparent=root";
Expand Down Expand Up @@ -375,9 +375,8 @@ void MakeObjectFromPackedArgs(ffi::PackedArgs args, Any* ret) {

// iterate through acenstors in parent to child order
// skip the first one since it is always the root object
TVM_FFI_ICHECK(type_info->type_acenstors[0] == TypeIndex::kTVMFFIObject);
for (int i = 1; i < type_info->type_depth; ++i) {
update_fields(TVMFFIGetTypeInfo(type_info->type_acenstors[i]));
update_fields(type_info->type_acenstors[i]);
}
update_fields(type_info);

Expand Down
4 changes: 2 additions & 2 deletions ffi/tests/cpp/test_object.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ TEST(Object, TypeInfo) {
EXPECT_TRUE(info != nullptr);
EXPECT_EQ(info->type_index, TIntObj::RuntimeTypeIndex());
EXPECT_EQ(info->type_depth, 2);
EXPECT_EQ(info->type_acenstors[0], Object::_type_index);
EXPECT_EQ(info->type_acenstors[1], TNumberObj::_type_index);
EXPECT_EQ(info->type_acenstors[0]->type_index, Object::_type_index);
EXPECT_EQ(info->type_acenstors[1]->type_index, TNumberObj::_type_index);
EXPECT_GE(info->type_index, TypeIndex::kTVMFFIDynObjectBegin);
}

Expand Down
30 changes: 26 additions & 4 deletions ffi/tests/cpp/test_reflection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
* under the License.
*/
#include <gtest/gtest.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/object.h>
#include <tvm/ffi/reflection/reflection.h>
#include <tvm/ffi/string.h>
Expand All @@ -29,11 +30,20 @@ namespace {
using namespace tvm::ffi;
using namespace tvm::ffi::testing;

struct A : public Object {
struct TestObjA : public Object {
int64_t x;
int64_t y;

static constexpr const char* _type_key = "test.TestObjA";
static constexpr bool _type_mutable = true;
TVM_FFI_DECLARE_BASE_OBJECT_INFO(TestObjA, Object);
};

struct TestObjADerived : public TestObjA {
int64_t z;

static constexpr const char* _type_key = "test.TestObjADerived";
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TestObjADerived, TestObjA);
};

TVM_FFI_STATIC_INIT_BLOCK({
Expand All @@ -56,12 +66,13 @@ TVM_FFI_STATIC_INIT_BLOCK({
return self->value - other;
});

refl::ObjectDef<A>().def_ro("x", &A::x).def_rw("y", &A::y);
refl::ObjectDef<TestObjA>().def_ro("x", &TestObjA::x).def_rw("y", &TestObjA::y);
refl::ObjectDef<TestObjADerived>().def_ro("z", &TestObjADerived::z);
});

TEST(Reflection, GetFieldByteOffset) {
EXPECT_EQ(reflection::GetFieldByteOffsetToObject(&A::x), sizeof(TVMFFIObject));
EXPECT_EQ(reflection::GetFieldByteOffsetToObject(&A::y), 8 + sizeof(TVMFFIObject));
EXPECT_EQ(reflection::GetFieldByteOffsetToObject(&TestObjA::x), sizeof(TVMFFIObject));
EXPECT_EQ(reflection::GetFieldByteOffsetToObject(&TestObjA::y), 8 + sizeof(TVMFFIObject));
EXPECT_EQ(reflection::GetFieldByteOffsetToObject(&TIntObj::value), sizeof(TVMFFIObject));
}

Expand Down Expand Up @@ -131,4 +142,15 @@ TEST(Reflection, CallMethod) {
EXPECT_EQ(prim_expr_sub(TPrimExpr("float", 1), 2.0).cast<double>(), -1.0);
}

TEST(Reflection, ForEachFieldInfo) {
const TypeInfo* info = TVMFFIGetTypeInfo(TestObjADerived::RuntimeTypeIndex());
Map<String, int> field_name_to_offset;
reflection::ForEachFieldInfo(info, [&](const TVMFFIFieldInfo* field_info) {
field_name_to_offset.Set(String(field_info->name), field_info->offset);
});
EXPECT_EQ(field_name_to_offset["x"], sizeof(TVMFFIObject));
EXPECT_EQ(field_name_to_offset["y"], 8 + sizeof(TVMFFIObject));
EXPECT_EQ(field_name_to_offset["z"], 16 + sizeof(TVMFFIObject));
}

} // namespace
10 changes: 10 additions & 0 deletions ffi/tests/cpp/testing_object.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

#include <tvm/ffi/memory.h>
#include <tvm/ffi/object.h>
#include <tvm/ffi/reflection/reflection.h>
#include <tvm/ffi/string.h>

namespace tvm {
Expand Down Expand Up @@ -81,6 +82,15 @@ class TFloatObj : public TNumberObj {

double Add(double other) const { return value + other; }

static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<TFloatObj>()
.def_ro("value", &TFloatObj::value, "float value field", refl::DefaultValue(10.0))
.def("sub",
[](const TFloatObj* self, double other) -> double { return self->value - other; })
.def("add", &TFloatObj::Add, "add method");
}

static constexpr const char* _type_key = "test.Float";
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TFloatObj, TNumberObj);
};
Expand Down
Loading