2424#define TVM_NODE_STRUCTURAL_EQUAL_H_
2525
2626#include < tvm/node/functor.h>
27+ #include < tvm/node/object_path.h>
2728#include < tvm/runtime/container/array.h>
2829#include < tvm/runtime/data_type.h>
2930
@@ -56,6 +57,27 @@ class BaseValueEqual {
5657 }
5758};
5859
60+ /* !
61+ * \brief Pair of `ObjectPath`s, one for each object being tested for structural equality.
62+ */
63+ class ObjectPathPairNode : public Object {
64+ public:
65+ ObjectPath lhs_path;
66+ ObjectPath rhs_path;
67+
68+ ObjectPathPairNode (ObjectPath lhs_path, ObjectPath rhs_path);
69+
70+ static constexpr const char * _type_key = " ObjectPathPair" ;
71+ TVM_DECLARE_FINAL_OBJECT_INFO (ObjectPathPairNode, Object);
72+ };
73+
74+ class ObjectPathPair : public ObjectRef {
75+ public:
76+ ObjectPathPair (ObjectPath lhs_path, ObjectPath rhs_path);
77+
78+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS (ObjectPathPair, ObjectRef, ObjectPathPairNode);
79+ };
80+
5981/* !
6082 * \brief Content-aware structural equality comparator for objects.
6183 *
@@ -99,7 +121,10 @@ class StructuralEqual : public BaseValueEqual {
99121 * equality checking. Instead, it can store the necessary equality conditions
100122 * and check later via an internally managed stack.
101123 */
102- class SEqualReducer : public BaseValueEqual {
124+ class SEqualReducer {
125+ private:
126+ struct PathTracingData ;
127+
103128 public:
104129 /* ! \brief Internal handler that defines custom behaviors.. */
105130 class Handler {
@@ -110,12 +135,24 @@ class SEqualReducer : public BaseValueEqual {
110135 * \param lhs The left operand.
111136 * \param rhs The right operand.
112137 * \param map_free_vars Whether do we allow remap variables if possible.
138+ * \param current_paths Optional paths to `lhs` and `rhs` objects, for error traceability.
113139 *
114140 * \return false if there is an immediate failure, true otherwise.
115141 * \note This function may save the equality condition of (lhs == rhs) in an internal
116142 * stack and try to resolve later.
117143 */
118- virtual bool SEqualReduce (const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) = 0;
144+ virtual bool SEqualReduce (const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
145+ const Optional<ObjectPathPair>& current_paths) = 0;
146+
147+ /* !
148+ * \brief Mark the comparison as failed, but don't fail immediately.
149+ *
150+ * This is useful for producing better error messages when comparing containers.
151+ * For example, if two array sizes mismatch, it's better to mark the comparison as failed
152+ * but compare array elements anyway, so that we could find the true first mismatch.
153+ */
154+ virtual void DeferFail (const ObjectPathPair& mismatch_paths) = 0;
155+
119156 /* !
120157 * \brief Lookup the graph node equal map for vars that are already mapped.
121158 *
@@ -129,28 +166,72 @@ class SEqualReducer : public BaseValueEqual {
129166 * \brief Mark current comparison as graph node equal comparison.
130167 */
131168 virtual void MarkGraphNode () = 0;
132- };
133169
134- using BaseValueEqual::operator ();
170+ protected:
171+ using PathTracingData = SEqualReducer::PathTracingData;
172+ };
135173
136174 /* ! \brief default constructor */
137175 SEqualReducer () = default ;
138176 /* !
139177 * \brief Constructor with a specific handler.
140178 * \param handler The equal handler for objects.
179+ * \param tracing_data Optional pointer to the path tracing data.
141180 * \param map_free_vars Whether or not to map free variables.
142181 */
143- explicit SEqualReducer (Handler* handler, bool map_free_vars)
144- : handler_(handler), map_free_vars_(map_free_vars) {}
182+ explicit SEqualReducer (Handler* handler, const PathTracingData* tracing_data, bool map_free_vars)
183+ : handler_(handler), tracing_data_(tracing_data), map_free_vars_(map_free_vars) {}
184+
185+ /* !
186+ * \brief Reduce condition to comparison of two attribute values.
187+ * \param lhs The left operand.
188+ * \param rhs The right operand.
189+ * \return the immediate check result.
190+ */
191+ bool operator ()(const double & lhs, const double & rhs) const ;
192+ bool operator ()(const int64_t & lhs, const int64_t & rhs) const ;
193+ bool operator ()(const uint64_t & lhs, const uint64_t & rhs) const ;
194+ bool operator ()(const int & lhs, const int & rhs) const ;
195+ bool operator ()(const bool & lhs, const bool & rhs) const ;
196+ bool operator ()(const std::string& lhs, const std::string& rhs) const ;
197+ bool operator ()(const DataType& lhs, const DataType& rhs) const ;
198+
199+ template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
200+ bool operator ()(const ENum& lhs, const ENum& rhs) const {
201+ using Underlying = typename std::underlying_type<ENum>::type;
202+ static_assert (std::is_same<Underlying, int >::value,
203+ " Enum must have `int` as the underlying type" );
204+ return EnumAttrsEqual (static_cast <int >(lhs), static_cast <int >(rhs), &lhs, &rhs);
205+ }
206+
207+ /* !
208+ * \brief Reduce condition to comparison of two objects.
209+ * \param lhs The left operand.
210+ * \param rhs The right operand.
211+ * \return the immediate check result.
212+ */
213+ bool operator ()(const ObjectRef& lhs, const ObjectRef& rhs) const ;
214+
145215 /* !
146216 * \brief Reduce condition to comparison of two objects.
217+ *
218+ * Like `operator()`, but with an additional `paths` parameter that specifies explicit object
219+ * paths for `lhs` and `rhs`. This is useful for implementing SEqualReduce() methods for container
220+ * objects like Array and Map, or other custom objects that store nested objects that are not
221+ * simply attributes.
222+ *
223+ * Can only be called when `IsPathTracingEnabled()` is `true`.
224+ *
147225 * \param lhs The left operand.
148226 * \param rhs The right operand.
227+ * \param paths Object paths for `lhs` and `rhs`.
149228 * \return the immediate check result.
150229 */
151- bool operator ()(const ObjectRef& lhs, const ObjectRef& rhs) const {
152- return handler_->SEqualReduce (lhs, rhs, map_free_vars_);
230+ bool operator ()(const ObjectRef& lhs, const ObjectRef& rhs, const ObjectPathPair& paths) const {
231+ ICHECK (IsPathTracingEnabled ()) << " Path tracing must be enabled when calling this function" ;
232+ return ObjectAttrsEqual (lhs, rhs, map_free_vars_, &paths);
153233 }
234+
154235 /* !
155236 * \brief Reduce condition to comparison of two definitions,
156237 * where free vars can be mapped.
@@ -162,9 +243,8 @@ class SEqualReducer : public BaseValueEqual {
162243 * \param rhs The right operand.
163244 * \return the immediate check result.
164245 */
165- bool DefEqual (const ObjectRef& lhs, const ObjectRef& rhs) {
166- return handler_->SEqualReduce (lhs, rhs, true );
167- }
246+ bool DefEqual (const ObjectRef& lhs, const ObjectRef& rhs);
247+
168248 /* !
169249 * \brief Reduce condition to comparison of two arrays.
170250 * \param lhs The left operand.
@@ -173,13 +253,20 @@ class SEqualReducer : public BaseValueEqual {
173253 */
174254 template <typename T>
175255 bool operator ()(const Array<T>& lhs, const Array<T>& rhs) const {
176- // quick specialization for Array to reduce amount of recursion
177- // depth as array comparison is pretty common.
178- if (lhs.size () != rhs.size ()) return false ;
179- for (size_t i = 0 ; i < lhs.size (); ++i) {
180- if (!(operator ()(lhs[i], rhs[i]))) return false ;
256+ if (tracing_data_ == nullptr ) {
257+ // quick specialization for Array to reduce amount of recursion
258+ // depth as array comparison is pretty common.
259+ if (lhs.size () != rhs.size ()) return false ;
260+ for (size_t i = 0 ; i < lhs.size (); ++i) {
261+ if (!(operator ()(lhs[i], rhs[i]))) return false ;
262+ }
263+ return true ;
181264 }
182- return true ;
265+
266+ // If tracing is enabled, fall back to the regular path
267+ const ObjectRef& lhs_obj = lhs;
268+ const ObjectRef& rhs_obj = rhs;
269+ return (*this )(lhs_obj, rhs_obj);
183270 }
184271 /* !
185272 * \brief Implementation for equality rule of var type objects(e.g. TypeVar, tir::Var).
@@ -198,11 +285,43 @@ class SEqualReducer : public BaseValueEqual {
198285 /* ! \return Get the internal handler. */
199286 Handler* operator ->() const { return handler_; }
200287
288+ /* ! \brief Check if this reducer is tracing paths to the first mismatch. */
289+ bool IsPathTracingEnabled () const { return tracing_data_ != nullptr ; }
290+
291+ /* !
292+ * \brief Get the paths of the currently compared objects.
293+ *
294+ * Can only be called when `IsPathTracingEnabled()` is true.
295+ */
296+ const ObjectPathPair& GetCurrentObjectPaths () const ;
297+
298+ /* !
299+ * \brief Specify the object paths of a detected mismatch.
300+ *
301+ * Can only be called when `IsPathTracingEnabled()` is true.
302+ */
303+ void RecordMismatchPaths (const ObjectPathPair& paths) const ;
304+
201305 private:
306+ bool EnumAttrsEqual (int lhs, int rhs, const void * lhs_address, const void * rhs_address) const ;
307+
308+ bool ObjectAttrsEqual (const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
309+ const ObjectPathPair* paths) const ;
310+
311+ static void GetPathsFromAttrAddressesAndStoreMismatch (const void * lhs_address,
312+ const void * rhs_address,
313+ const PathTracingData* tracing_data);
314+
315+ template <typename T>
316+ static bool CompareAttributeValues (const T& lhs, const T& rhs,
317+ const PathTracingData* tracing_data);
318+
202319 /* ! \brief Internal class pointer. */
203- Handler* handler_;
320+ Handler* handler_ = nullptr ;
321+ /* ! \brief Pointer to the current path tracing context, or nullptr if path tracing is disabled. */
322+ const PathTracingData* tracing_data_ = nullptr ;
204323 /* ! \brief Whether or not to map free vars. */
205- bool map_free_vars_;
324+ bool map_free_vars_ = false ;
206325};
207326
208327} // namespace tvm
0 commit comments