Skip to content

Commit 2ec11f5

Browse files
committed
[FFI][REFACTOR] Enable custom s_hash/equal (apache#18165)
This PR enables custom s_hash/s_equal functions via TypeAttr, and also enhances the Var comparison by checking content so we can precheck type signatures.
1 parent 162d600 commit 2ec11f5

File tree

9 files changed

+222
-29
lines changed

9 files changed

+222
-29
lines changed

CMakeLists.txt

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ project(
2626

2727
option(TVM_FFI_BUILD_TESTS "Adding test targets." OFF)
2828
option(TVM_FFI_USE_LIBBACKTRACE "Enable libbacktrace" ON)
29+
option(TVM_FFI_USE_EXTRA_CXX_API "Enable extra CXX API in shared lib" ON)
2930
option(TVM_FFI_BACKTRACE_ON_SEGFAULT "Set signal handler to print traceback on segfault" ON)
3031

3132
include(cmake/Utils/CxxWarning.cmake)
@@ -47,7 +48,8 @@ target_include_directories(tvm_ffi_header INTERFACE "${CMAKE_CURRENT_SOURCE_DIR}
4748
target_link_libraries(tvm_ffi_header INTERFACE dlpack_header)
4849

4950
########## Target: `tvm_ffi` ##########
50-
add_library(tvm_ffi_objs OBJECT
51+
52+
set(tvm_ffi_objs_sources
5153
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/traceback.cc"
5254
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/traceback_win.cc"
5355
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/object.cc"
@@ -57,10 +59,18 @@ add_library(tvm_ffi_objs OBJECT
5759
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/dtype.cc"
5860
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/testing.cc"
5961
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/container.cc"
60-
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/access_path.cc"
61-
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/structural_equal.cc"
62-
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/structural_hash.cc"
6362
)
63+
64+
if (TVM_FFI_USE_EXTRA_CXX_API)
65+
list(APPEND tvm_ffi_objs_sources
66+
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/access_path.cc"
67+
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/structural_equal.cc"
68+
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/reflection/structural_hash.cc"
69+
)
70+
endif()
71+
72+
add_library(tvm_ffi_objs OBJECT ${tvm_ffi_objs_sources})
73+
6474
set_target_properties(
6575
tvm_ffi_objs PROPERTIES
6676
POSITION_INDEPENDENT_CODE ON

include/tvm/ffi/c_api.h

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,27 @@ typedef enum {
424424
* is only an unique copy of each value.
425425
*/
426426
kTVMFFISEqHashKindUniqueInstance = 5,
427+
/*!
428+
* \brief provide custom __s_equal__ and __s_hash__ functions through TypeAttrColumn.
429+
*
430+
* The function signatures are (defined via ffi::Function):
431+
*
432+
* \code
433+
* bool __s_equal__(
434+
* ObjectRefType self, ObjectRefType other,
435+
* ffi::TypedFunction<bool(AnyView, AnyView, bool def_region, string field_name)> cmp
436+
* );
437+
*
438+
* uint64_t __s_hash__(
439+
* ObjectRefType self, uint64_t type_key_hash,
440+
* ffi::TypedFunction<uint64_t(AnyView, bool def_region)> hash
441+
* );
442+
* \endcode
443+
*
444+
* Where the extra string field in cmp is the name of the field that is being compared.
445+
* The function should be registered through TVMFFITypeRegisterAttr via reflection::TypeAttrDef.
446+
*/
447+
kTVMFFISEqHashKindCustomTreeNode = 6,
427448
#ifdef __cplusplus
428449
};
429450
#else
@@ -539,7 +560,9 @@ typedef struct {
539560
/*
540561
* \brief Column array that stores extra attributes about types
541562
*
542-
* The attributes stored in column arrays that can be looked up by type index.
563+
* The attributes stored in a column array that can be looked up by type index.
564+
* Note that the TypeAttr behaves like type_traits, so column[T] does not contain
565+
* attributes from base classes.
543566
*
544567
* \note
545568
* \sa TVMFFIRegisterTypeAttr

src/ffi/object.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,8 @@ class TypeTable {
247247
column_index = type_attr_columns_.size();
248248
type_attr_columns_.emplace_back(std::make_unique<TypeAttrColumnData>());
249249
type_attr_name_to_column_index_.Set(name_str, column_index);
250+
} else {
251+
column_index = (*it).second;
250252
}
251253
TypeAttrColumnData* column = type_attr_columns_[column_index].get();
252254
if (column->data_.size() < static_cast<size_t>(type_index + 1)) {

src/ffi/reflection/structural_equal.cc

Lines changed: 53 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -119,13 +119,7 @@ class StructEqualHandler {
119119
}
120120

121121
bool success = true;
122-
if (structural_eq_hash_kind == kTVMFFISEqHashKindFreeVar) {
123-
// we are in a free var case that is not yet mapped.
124-
// in this case, either map_free_vars_ should be set to true, or map_free_vars_ should be set
125-
if (!lhs.same_as(rhs) && !map_free_vars_) {
126-
success = false;
127-
}
128-
} else {
122+
if (structural_eq_hash_kind != kTVMFFISEqHashKindCustomTreeNode) {
129123
// We recursively compare the fields of the object
130124
ForEachFieldInfoWithEarlyStop(type_info, [&](const TVMFFIFieldInfo* field_info) {
131125
// skip fields that are marked as structural eq hash ignore
@@ -158,11 +152,57 @@ class StructEqualHandler {
158152
return false;
159153
}
160154
});
155+
} else {
156+
static reflection::TypeAttrColumn custom_s_equal = reflection::TypeAttrColumn("__s_equal__");
157+
// run custom equal function defined via __s_equal__ type attribute
158+
if (s_equal_callback_ == nullptr) {
159+
s_equal_callback_ = ffi::Function::FromTyped(
160+
[this](AnyView lhs, AnyView rhs, bool def_region, AnyView field_name) {
161+
// NOTE: we explicitly make field_name as AnyView to avoid copy overhead initially
162+
// and only cast to string if mismatch happens
163+
bool success = true;
164+
if (def_region) {
165+
bool allow_free_var = true;
166+
std::swap(allow_free_var, map_free_vars_);
167+
success = CompareAny(lhs, rhs);
168+
std::swap(allow_free_var, map_free_vars_);
169+
} else {
170+
success = CompareAny(lhs, rhs);
171+
}
172+
if (!success) {
173+
if (mismatch_lhs_reverse_path_ != nullptr) {
174+
String field_name_str = field_name.cast<String>();
175+
mismatch_lhs_reverse_path_->emplace_back(AccessStep::ObjectField(field_name_str));
176+
mismatch_rhs_reverse_path_->emplace_back(AccessStep::ObjectField(field_name_str));
177+
}
178+
}
179+
return success;
180+
});
181+
}
182+
TVM_FFI_ICHECK(custom_s_equal[type_info->type_index] != nullptr)
183+
<< "TypeAttr `__s_equal__` is not registered for type `" << String(type_info->type_key)
184+
<< "`";
185+
success = custom_s_equal[type_info->type_index]
186+
.cast<ffi::Function>()(lhs, rhs, s_equal_callback_)
187+
.cast<bool>();
161188
}
189+
162190
if (success) {
191+
if (structural_eq_hash_kind == kTVMFFISEqHashKindFreeVar) {
192+
// we are in a free var case that is not yet mapped.
193+
// in this case, either lhs and rhs are the same object, or map_free_vars_ should be
194+
// set
195+
if (lhs.same_as(rhs) || map_free_vars_) {
196+
// record the equality
197+
equal_map_lhs_[lhs] = rhs;
198+
equal_map_rhs_[rhs] = lhs;
199+
return true;
200+
} else {
201+
return false;
202+
}
203+
}
163204
// if we have a success mapping and in graph/var mode, record the equality mapping
164-
if (structural_eq_hash_kind == kTVMFFISEqHashKindDAGNode ||
165-
structural_eq_hash_kind == kTVMFFISEqHashKindFreeVar) {
205+
if (structural_eq_hash_kind == kTVMFFISEqHashKindDAGNode) {
166206
// record the equality
167207
equal_map_lhs_[lhs] = rhs;
168208
equal_map_rhs_[rhs] = lhs;
@@ -306,6 +346,8 @@ class StructEqualHandler {
306346
// the root lhs for result printing
307347
std::vector<AccessStep>* mismatch_lhs_reverse_path_ = nullptr;
308348
std::vector<AccessStep>* mismatch_rhs_reverse_path_ = nullptr;
349+
// lazily initialize custom equal function
350+
ffi::Function s_equal_callback_ = nullptr;
309351
// map from lhs to rhs
310352
std::unordered_map<ObjectRef, ObjectRef, ObjectPtrHash, ObjectPtrEqual> equal_map_lhs_;
311353
// map from rhs to lhs
@@ -342,6 +384,8 @@ TVM_FFI_STATIC_INIT_BLOCK({
342384
namespace refl = tvm::ffi::reflection;
343385
refl::GlobalDef().def("ffi.reflection.GetFirstStructuralMismatch",
344386
StructuralEqual::GetFirstMismatch);
387+
// ensure the type attribute column is present in the system even if it is empty.
388+
refl::EnsureTypeAttrColumn("__s_equal__");
345389
});
346390

347391
} // namespace reflection

src/ffi/reflection/structural_hash.cc

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -99,15 +99,7 @@ class StructuralHashHandler {
9999

100100
// compute the hash value
101101
uint64_t hash_value = obj->GetTypeKeyHash();
102-
if (structural_eq_hash_kind == kTVMFFISEqHashKindFreeVar) {
103-
if (map_free_vars_) {
104-
// use lexical order of free var and its type
105-
hash_value = details::StableHashCombine(hash_value, free_var_counter_++);
106-
} else {
107-
// Fallback to pointer hash, we are not mapping free var.
108-
return std::hash<const Object*>()(obj.get());
109-
}
110-
} else {
102+
if (structural_eq_hash_kind != kTVMFFISEqHashKindCustomTreeNode) {
111103
// go over the content and hash the fields
112104
ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* field_info) {
113105
// skip fields that are marked as structural eq hash ignore
@@ -126,12 +118,43 @@ class StructuralHashHandler {
126118
}
127119
}
128120
});
129-
// if it is a DAG node, also record the lexical order of graph counter
130-
// this helps to distinguish DAG from trees.
131-
if (structural_eq_hash_kind == kTVMFFISEqHashKindDAGNode) {
132-
hash_value = details::StableHashCombine(hash_value, graph_node_counter_++);
121+
} else {
122+
static reflection::TypeAttrColumn custom_s_hash = reflection::TypeAttrColumn("__s_hash__");
123+
TVM_FFI_ICHECK(custom_s_hash[type_info->type_index] != nullptr)
124+
<< "TypeAttr `__s_hash__` is not registered for type `" << String(type_info->type_key)
125+
<< "`";
126+
if (s_hash_callback_ == nullptr) {
127+
s_hash_callback_ = ffi::Function::FromTyped([this](AnyView val, bool def_region) {
128+
if (def_region) {
129+
bool allow_free_var = true;
130+
std::swap(allow_free_var, map_free_vars_);
131+
uint64_t hash_value = HashAny(val);
132+
std::swap(allow_free_var, map_free_vars_);
133+
return hash_value;
134+
} else {
135+
return HashAny(val);
136+
}
137+
});
138+
}
139+
hash_value = custom_s_hash[type_info->type_index]
140+
.cast<ffi::Function>()(obj, hash_value, s_hash_callback_)
141+
.cast<uint64_t>();
142+
}
143+
144+
if (structural_eq_hash_kind == kTVMFFISEqHashKindFreeVar) {
145+
if (map_free_vars_) {
146+
// use lexical order of free var and its type
147+
hash_value = details::StableHashCombine(hash_value, free_var_counter_++);
148+
} else {
149+
// Fallback to pointer hash, we are not mapping free var.
150+
hash_value = std::hash<const Object*>()(obj.get());
133151
}
134152
}
153+
// if it is a DAG node, also record the lexical order of graph counter
154+
// this helps to distinguish DAG from trees.
155+
if (structural_eq_hash_kind == kTVMFFISEqHashKindDAGNode) {
156+
hash_value = details::StableHashCombine(hash_value, graph_node_counter_++);
157+
}
135158
// record the hash value for this object
136159
hash_memo_[obj] = hash_value;
137160
return hash_value;
@@ -244,6 +267,8 @@ class StructuralHashHandler {
244267
uint32_t free_var_counter_{0};
245268
// graph node counter.
246269
uint32_t graph_node_counter_{0};
270+
// lazily initialize custom hash function
271+
ffi::Function s_hash_callback_ = nullptr;
247272
// memo map from object to its computed hash value
248273
std::unordered_map<ObjectRef, uint64_t, ObjectPtrHash, ObjectPtrEqual> hash_memo_;
249274
};
@@ -258,6 +283,7 @@ uint64_t StructuralHash::Hash(const Any& value, bool map_free_vars, bool skip_nd
258283
TVM_FFI_STATIC_INIT_BLOCK({
259284
namespace refl = tvm::ffi::reflection;
260285
refl::GlobalDef().def("ffi.reflection.StructuralHash", StructuralHash::Hash);
286+
refl::EnsureTypeAttrColumn("__s_hash__");
261287
});
262288

263289
} // namespace reflection

tests/cpp/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
11
file(GLOB _test_sources "${CMAKE_CURRENT_SOURCE_DIR}/test*.cc")
2+
file(GLOB _test_extra_sources "${CMAKE_CURRENT_SOURCE_DIR}/extra/test*.cc")
3+
4+
if (TVM_FFI_USE_EXTRA_CXX_API)
5+
list(APPEND _test_sources ${_test_extra_sources})
6+
endif()
7+
28
add_executable(
39
tvm_ffi_tests
410
EXCLUDE_FROM_ALL

tests/cpp/test_reflection_structural_equal_hash.cc renamed to tests/cpp/extra/test_reflection_structural_equal_hash.cc

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
#include <tvm/ffi/reflection/structural_hash.h>
2727
#include <tvm/ffi/string.h>
2828

29-
#include "./testing_object.h"
29+
#include "../testing_object.h"
3030

3131
namespace {
3232

@@ -169,4 +169,30 @@ TEST(StructuralEqualHash, FuncDefAndIgnoreField) {
169169
EXPECT_TRUE(refl::StructuralEqual()(diff_fa_fc, expected_diff_fa_fc));
170170
}
171171

172+
TEST(StructuralEqualHash, CustomTreeNode) {
173+
TVar x = TVar("x");
174+
TVar y = TVar("y");
175+
// comment fields are ignored
176+
TCustomFunc fa = TCustomFunc({x}, {TInt(1), x}, "comment a");
177+
TCustomFunc fb = TCustomFunc({y}, {TInt(1), y}, "comment b");
178+
179+
TCustomFunc fc = TCustomFunc({x}, {TInt(1), TInt(2)}, "comment c");
180+
181+
EXPECT_TRUE(refl::StructuralEqual()(fa, fb));
182+
EXPECT_EQ(refl::StructuralHash()(fa), refl::StructuralHash()(fb));
183+
184+
EXPECT_FALSE(refl::StructuralEqual()(fa, fc));
185+
auto diff_fa_fc = refl::StructuralEqual::GetFirstMismatch(fa, fc);
186+
auto expected_diff_fa_fc = refl::AccessPathPair(refl::AccessPath({
187+
refl::AccessStep::ObjectField("body"),
188+
refl::AccessStep::ArrayIndex(1),
189+
}),
190+
refl::AccessPath({
191+
refl::AccessStep::ObjectField("body"),
192+
refl::AccessStep::ArrayIndex(1),
193+
}));
194+
EXPECT_TRUE(diff_fa_fc.has_value());
195+
EXPECT_TRUE(refl::StructuralEqual()(diff_fa_fc, expected_diff_fa_fc));
196+
}
197+
172198
} // namespace

tests/cpp/test_reflection_accessor.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
5555
TPrimExprObj::RegisterReflection();
5656
TVarObj::RegisterReflection();
5757
TFuncObj::RegisterReflection();
58+
TCustomFuncObj::RegisterReflection();
5859

5960
refl::ObjectDef<TestObjA>().def_ro("x", &TestObjA::x).def_rw("y", &TestObjA::y);
6061
refl::ObjectDef<TestObjADerived>().def_ro("z", &TestObjADerived::z);

tests/cpp/testing_object.h

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,8 @@ class TVarObj : public Object {
158158

159159
static void RegisterReflection() {
160160
namespace refl = tvm::ffi::reflection;
161-
refl::ObjectDef<TVarObj>().def_ro("name", &TVarObj::name);
161+
refl::ObjectDef<TVarObj>().def_ro("name", &TVarObj::name,
162+
refl::AttachFieldFlag::SEqHashIgnore());
162163
}
163164

164165
static constexpr const char* _type_key = "test.Var";
@@ -204,6 +205,60 @@ class TFunc : public ObjectRef {
204205
TVM_FFI_DEFINE_OBJECT_REF_METHODS(TFunc, ObjectRef, TFuncObj);
205206
};
206207

208+
class TCustomFuncObj : public Object {
209+
public:
210+
Array<TVar> params;
211+
Array<ObjectRef> body;
212+
String comment;
213+
214+
TCustomFuncObj(Array<TVar> params, Array<ObjectRef> body, String comment)
215+
: params(params), body(body), comment(comment) {}
216+
217+
bool SEqual(const TCustomFuncObj* other,
218+
ffi::TypedFunction<bool(AnyView, AnyView, bool, AnyView)> cmp) const {
219+
if (!cmp(params, other->params, true, "params")) {
220+
std::cout << "custom s_equal failed params" << std::endl;
221+
return false;
222+
}
223+
if (!cmp(body, other->body, false, "body")) {
224+
std::cout << "custom s_equal failed body" << std::endl;
225+
return false;
226+
}
227+
return true;
228+
}
229+
230+
uint64_t SHash(uint64_t type_key_hash, ffi::TypedFunction<uint64_t(AnyView, bool)> hash) const {
231+
uint64_t hash_value = type_key_hash;
232+
hash_value = tvm::ffi::details::StableHashCombine(hash_value, hash(params, true));
233+
hash_value = tvm::ffi::details::StableHashCombine(hash_value, hash(body, false));
234+
return hash_value;
235+
}
236+
237+
static void RegisterReflection() {
238+
namespace refl = tvm::ffi::reflection;
239+
refl::ObjectDef<TCustomFuncObj>()
240+
.def_ro("params", &TCustomFuncObj::params)
241+
.def_ro("body", &TCustomFuncObj::body)
242+
.def_ro("comment", &TCustomFuncObj::comment);
243+
refl::TypeAttrDef<TCustomFuncObj>()
244+
.def("__s_equal__", &TCustomFuncObj::SEqual)
245+
.def("__s_hash__", &TCustomFuncObj::SHash);
246+
}
247+
248+
static constexpr const char* _type_key = "test.CustomFunc";
249+
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindCustomTreeNode;
250+
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TCustomFuncObj, Object);
251+
};
252+
253+
class TCustomFunc : public ObjectRef {
254+
public:
255+
explicit TCustomFunc(Array<TVar> params, Array<ObjectRef> body, String comment) {
256+
data_ = make_object<TCustomFuncObj>(params, body, comment);
257+
}
258+
259+
TVM_FFI_DEFINE_OBJECT_REF_METHODS(TCustomFunc, ObjectRef, TCustomFuncObj);
260+
};
261+
207262
} // namespace testing
208263

209264
template <>

0 commit comments

Comments
 (0)