@@ -119,13 +119,7 @@ class StructEqualHandler {
119119 }
120120
121121 bool success = true ;
122- if (structural_eq_hash_kind == kTVMFFISEqHashKindFreeVar ) {
123- // we are in a free var case that is not yet mapped.
124- // in this case, either map_free_vars_ should be set to true, or map_free_vars_ should be set
125- if (!lhs.same_as (rhs) && !map_free_vars_) {
126- success = false ;
127- }
128- } else {
122+ if (structural_eq_hash_kind != kTVMFFISEqHashKindCustomTreeNode ) {
129123 // We recursively compare the fields the object
130124 ForEachFieldInfoWithEarlyStop (type_info, [&](const TVMFFIFieldInfo* field_info) {
131125 // skip fields that are marked as structural eq hash ignore
@@ -158,11 +152,57 @@ class StructEqualHandler {
158152 return false ;
159153 }
160154 });
155+ } else {
156+ static reflection::TypeAttrColumn custom_s_equal = reflection::TypeAttrColumn (" __s_equal__" );
157+ // run custom equal function defined via __s_equal__ type attribute
158+ if (s_equal_callback_ == nullptr ) {
159+ s_equal_callback_ = ffi::Function::FromTyped (
160+ [this ](AnyView lhs, AnyView rhs, bool def_region, AnyView field_name) {
161+ // NOTE: we explicitly make field_name as AnyView to avoid copy overhead initially
162+ // and only cast to string if mismatch happens
163+ bool success = true ;
164+ if (def_region) {
165+ bool allow_free_var = true ;
166+ std::swap (allow_free_var, map_free_vars_);
167+ success = CompareAny (lhs, rhs);
168+ std::swap (allow_free_var, map_free_vars_);
169+ } else {
170+ success = CompareAny (lhs, rhs);
171+ }
172+ if (!success) {
173+ if (mismatch_lhs_reverse_path_ != nullptr ) {
174+ String field_name_str = field_name.cast <String>();
175+ mismatch_lhs_reverse_path_->emplace_back (AccessStep::ObjectField (field_name_str));
176+ mismatch_rhs_reverse_path_->emplace_back (AccessStep::ObjectField (field_name_str));
177+ }
178+ }
179+ return success;
180+ });
181+ }
182+ TVM_FFI_ICHECK (custom_s_equal[type_info->type_index ] != nullptr )
183+ << " TypeAttr `__s_equal__` is not registered for type `" << String (type_info->type_key )
184+ << " `" ;
185+ success = custom_s_equal[type_info->type_index ]
186+ .cast <ffi::Function>()(lhs, rhs, s_equal_callback_)
187+ .cast <bool >();
161188 }
189+
162190 if (success) {
191+ if (structural_eq_hash_kind == kTVMFFISEqHashKindFreeVar ) {
192+ // we are in a free var case that is not yet mapped.
193+ // in this case, either map_free_vars_ should be set to true, or map_free_vars_ should be
194+ // set
195+ if (lhs.same_as (rhs) || map_free_vars_) {
196+ // record the equality
197+ equal_map_lhs_[lhs] = rhs;
198+ equal_map_rhs_[rhs] = lhs;
199+ return true ;
200+ } else {
201+ return false ;
202+ }
203+ }
163204 // if we have a success mapping and in graph/var mode, record the equality mapping
164- if (structural_eq_hash_kind == kTVMFFISEqHashKindDAGNode ||
165- structural_eq_hash_kind == kTVMFFISEqHashKindFreeVar ) {
205+ if (structural_eq_hash_kind == kTVMFFISEqHashKindDAGNode ) {
166206 // record the equality
167207 equal_map_lhs_[lhs] = rhs;
168208 equal_map_rhs_[rhs] = lhs;
@@ -306,6 +346,8 @@ class StructEqualHandler {
306346 // the root lhs for result printing
307347 std::vector<AccessStep>* mismatch_lhs_reverse_path_ = nullptr ;
308348 std::vector<AccessStep>* mismatch_rhs_reverse_path_ = nullptr ;
349+ // lazily initialize custom equal function
350+ ffi::Function s_equal_callback_ = nullptr ;
309351 // map from lhs to rhs
310352 std::unordered_map<ObjectRef, ObjectRef, ObjectPtrHash, ObjectPtrEqual> equal_map_lhs_;
311353 // map from rhs to lhs
@@ -342,6 +384,8 @@ TVM_FFI_STATIC_INIT_BLOCK({
342384 namespace refl = tvm::ffi::reflection;
343385 refl::GlobalDef ().def (" ffi.reflection.GetFirstStructuralMismatch" ,
344386 StructuralEqual::GetFirstMismatch);
387+ // ensure the type attribute column is presented in the system even if it is empty.
388+ refl::EnsureTypeAttrColumn (" __s_equal__" );
345389});
346390
347391} // namespace reflection
0 commit comments