Skip to content

Commit 39ffe0a

Browse files
authored
[TVMScript] Add object path tracing to StructuralEqual (#12101)
Motivation: when two IR objects fail a structural equality check, currently there is no easy way to find out which part of the IR caused the mismatch. In this PR, we modify the `StructuralEqual` infrastructure to also optionally return a pair of `ObjectPath` objects that point to the mismatch. (See #11977). In the upcoming PRs, we will pass these paths to the TIR printer, so that it could highlight the mismatch location nicely. Tracking issue: #11912
1 parent 85624ff commit 39ffe0a

File tree

11 files changed

+969
-44
lines changed

11 files changed

+969
-44
lines changed

include/tvm/node/reflection.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,5 +404,11 @@ inline bool ReflectionVTable::GetReprBytes(const Object* self, std::string* repr
404404
}
405405
}
406406

407+
/*!
408+
* \brief Given an object and an address of its attribute, return the key of the attribute.
409+
* \return nullptr if no attribute with the given address exists.
410+
*/
411+
Optional<String> GetAttrKeyByAddress(const Object* object, const void* attr_address);
412+
407413
} // namespace tvm
408414
#endif // TVM_NODE_REFLECTION_H_

include/tvm/node/structural_equal.h

Lines changed: 138 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#define TVM_NODE_STRUCTURAL_EQUAL_H_
2525

2626
#include <tvm/node/functor.h>
27+
#include <tvm/node/object_path.h>
2728
#include <tvm/runtime/container/array.h>
2829
#include <tvm/runtime/data_type.h>
2930

@@ -56,6 +57,27 @@ class BaseValueEqual {
5657
}
5758
};
5859

60+
/*!
61+
* \brief Pair of `ObjectPath`s, one for each object being tested for structural equality.
62+
*/
63+
class ObjectPathPairNode : public Object {
64+
public:
65+
ObjectPath lhs_path;
66+
ObjectPath rhs_path;
67+
68+
ObjectPathPairNode(ObjectPath lhs_path, ObjectPath rhs_path);
69+
70+
static constexpr const char* _type_key = "ObjectPathPair";
71+
TVM_DECLARE_FINAL_OBJECT_INFO(ObjectPathPairNode, Object);
72+
};
73+
74+
class ObjectPathPair : public ObjectRef {
75+
public:
76+
ObjectPathPair(ObjectPath lhs_path, ObjectPath rhs_path);
77+
78+
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectPathPair, ObjectRef, ObjectPathPairNode);
79+
};
80+
5981
/*!
6082
* \brief Content-aware structural equality comparator for objects.
6183
*
@@ -99,7 +121,10 @@ class StructuralEqual : public BaseValueEqual {
99121
* equality checking. Instead, it can store the necessary equality conditions
100122
* and check later via an internally managed stack.
101123
*/
102-
class SEqualReducer : public BaseValueEqual {
124+
class SEqualReducer {
125+
private:
126+
struct PathTracingData;
127+
103128
public:
104129
/*! \brief Internal handler that defines custom behaviors.. */
105130
class Handler {
@@ -110,12 +135,24 @@ class SEqualReducer : public BaseValueEqual {
110135
* \param lhs The left operand.
111136
* \param rhs The right operand.
112137
* \param map_free_vars Whether do we allow remap variables if possible.
138+
* \param current_paths Optional paths to `lhs` and `rhs` objects, for error traceability.
113139
*
114140
* \return false if there is an immediate failure, true otherwise.
115141
* \note This function may save the equality condition of (lhs == rhs) in an internal
116142
* stack and try to resolve later.
117143
*/
118-
virtual bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) = 0;
144+
virtual bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
145+
const Optional<ObjectPathPair>& current_paths) = 0;
146+
147+
/*!
148+
* \brief Mark the comparison as failed, but don't fail immediately.
149+
*
150+
* This is useful for producing better error messages when comparing containers.
151+
* For example, if two array sizes mismatch, it's better to mark the comparison as failed
152+
* but compare array elements anyway, so that we could find the true first mismatch.
153+
*/
154+
virtual void DeferFail(const ObjectPathPair& mismatch_paths) = 0;
155+
119156
/*!
120157
* \brief Lookup the graph node equal map for vars that are already mapped.
121158
*
@@ -129,28 +166,72 @@ class SEqualReducer : public BaseValueEqual {
129166
* \brief Mark current comparison as graph node equal comparison.
130167
*/
131168
virtual void MarkGraphNode() = 0;
132-
};
133169

134-
using BaseValueEqual::operator();
170+
protected:
171+
using PathTracingData = SEqualReducer::PathTracingData;
172+
};
135173

136174
/*! \brief default constructor */
137175
SEqualReducer() = default;
138176
/*!
139177
* \brief Constructor with a specific handler.
140178
* \param handler The equal handler for objects.
179+
* \param tracing_data Optional pointer to the path tracing data.
141180
* \param map_free_vars Whether or not to map free variables.
142181
*/
143-
explicit SEqualReducer(Handler* handler, bool map_free_vars)
144-
: handler_(handler), map_free_vars_(map_free_vars) {}
182+
explicit SEqualReducer(Handler* handler, const PathTracingData* tracing_data, bool map_free_vars)
183+
: handler_(handler), tracing_data_(tracing_data), map_free_vars_(map_free_vars) {}
184+
185+
/*!
186+
* \brief Reduce condition to comparison of two attribute values.
187+
* \param lhs The left operand.
188+
* \param rhs The right operand.
189+
* \return the immediate check result.
190+
*/
191+
bool operator()(const double& lhs, const double& rhs) const;
192+
bool operator()(const int64_t& lhs, const int64_t& rhs) const;
193+
bool operator()(const uint64_t& lhs, const uint64_t& rhs) const;
194+
bool operator()(const int& lhs, const int& rhs) const;
195+
bool operator()(const bool& lhs, const bool& rhs) const;
196+
bool operator()(const std::string& lhs, const std::string& rhs) const;
197+
bool operator()(const DataType& lhs, const DataType& rhs) const;
198+
199+
template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
200+
bool operator()(const ENum& lhs, const ENum& rhs) const {
201+
using Underlying = typename std::underlying_type<ENum>::type;
202+
static_assert(std::is_same<Underlying, int>::value,
203+
"Enum must have `int` as the underlying type");
204+
return EnumAttrsEqual(static_cast<int>(lhs), static_cast<int>(rhs), &lhs, &rhs);
205+
}
206+
207+
/*!
208+
* \brief Reduce condition to comparison of two objects.
209+
* \param lhs The left operand.
210+
* \param rhs The right operand.
211+
* \return the immediate check result.
212+
*/
213+
bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const;
214+
145215
/*!
146216
* \brief Reduce condition to comparison of two objects.
217+
*
218+
* Like `operator()`, but with an additional `paths` parameter that specifies explicit object
219+
* paths for `lhs` and `rhs`. This is useful for implementing SEqualReduce() methods for container
220+
* objects like Array and Map, or other custom objects that store nested objects that are not
221+
* simply attributes.
222+
*
223+
* Can only be called when `IsPathTracingEnabled()` is `true`.
224+
*
147225
* \param lhs The left operand.
148226
* \param rhs The right operand.
227+
* \param paths Object paths for `lhs` and `rhs`.
149228
* \return the immediate check result.
150229
*/
151-
bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const {
152-
return handler_->SEqualReduce(lhs, rhs, map_free_vars_);
230+
bool operator()(const ObjectRef& lhs, const ObjectRef& rhs, const ObjectPathPair& paths) const {
231+
ICHECK(IsPathTracingEnabled()) << "Path tracing must be enabled when calling this function";
232+
return ObjectAttrsEqual(lhs, rhs, map_free_vars_, &paths);
153233
}
234+
154235
/*!
155236
* \brief Reduce condition to comparison of two definitions,
156237
* where free vars can be mapped.
@@ -162,9 +243,8 @@ class SEqualReducer : public BaseValueEqual {
162243
* \param rhs The right operand.
163244
* \return the immediate check result.
164245
*/
165-
bool DefEqual(const ObjectRef& lhs, const ObjectRef& rhs) {
166-
return handler_->SEqualReduce(lhs, rhs, true);
167-
}
246+
bool DefEqual(const ObjectRef& lhs, const ObjectRef& rhs);
247+
168248
/*!
169249
* \brief Reduce condition to comparison of two arrays.
170250
* \param lhs The left operand.
@@ -173,13 +253,20 @@ class SEqualReducer : public BaseValueEqual {
173253
*/
174254
template <typename T>
175255
bool operator()(const Array<T>& lhs, const Array<T>& rhs) const {
176-
// quick specialization for Array to reduce amount of recursion
177-
// depth as array comparison is pretty common.
178-
if (lhs.size() != rhs.size()) return false;
179-
for (size_t i = 0; i < lhs.size(); ++i) {
180-
if (!(operator()(lhs[i], rhs[i]))) return false;
256+
if (tracing_data_ == nullptr) {
257+
// quick specialization for Array to reduce amount of recursion
258+
// depth as array comparison is pretty common.
259+
if (lhs.size() != rhs.size()) return false;
260+
for (size_t i = 0; i < lhs.size(); ++i) {
261+
if (!(operator()(lhs[i], rhs[i]))) return false;
262+
}
263+
return true;
181264
}
182-
return true;
265+
266+
// If tracing is enabled, fall back to the regular path
267+
const ObjectRef& lhs_obj = lhs;
268+
const ObjectRef& rhs_obj = rhs;
269+
return (*this)(lhs_obj, rhs_obj);
183270
}
184271
/*!
185272
* \brief Implementation for equality rule of var type objects(e.g. TypeVar, tir::Var).
@@ -198,11 +285,43 @@ class SEqualReducer : public BaseValueEqual {
198285
/*! \return Get the internal handler. */
199286
Handler* operator->() const { return handler_; }
200287

288+
/*! \brief Check if this reducer is tracing paths to the first mismatch. */
289+
bool IsPathTracingEnabled() const { return tracing_data_ != nullptr; }
290+
291+
/*!
292+
* \brief Get the paths of the currently compared objects.
293+
*
294+
* Can only be called when `IsPathTracingEnabled()` is true.
295+
*/
296+
const ObjectPathPair& GetCurrentObjectPaths() const;
297+
298+
/*!
299+
* \brief Specify the object paths of a detected mismatch.
300+
*
301+
* Can only be called when `IsPathTracingEnabled()` is true.
302+
*/
303+
void RecordMismatchPaths(const ObjectPathPair& paths) const;
304+
201305
private:
306+
bool EnumAttrsEqual(int lhs, int rhs, const void* lhs_address, const void* rhs_address) const;
307+
308+
bool ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
309+
const ObjectPathPair* paths) const;
310+
311+
static void GetPathsFromAttrAddressesAndStoreMismatch(const void* lhs_address,
312+
const void* rhs_address,
313+
const PathTracingData* tracing_data);
314+
315+
template <typename T>
316+
static bool CompareAttributeValues(const T& lhs, const T& rhs,
317+
const PathTracingData* tracing_data);
318+
202319
/*! \brief Internal class pointer. */
203-
Handler* handler_;
320+
Handler* handler_ = nullptr;
321+
/*! \brief Pointer to the current path tracing context, or nullptr if path tracing is disabled. */
322+
const PathTracingData* tracing_data_ = nullptr;
204323
/*! \brief Whether or not to map free vars. */
205-
bool map_free_vars_;
324+
bool map_free_vars_ = false;
206325
};
207326

208327
} // namespace tvm

python/tvm/ir/base.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,8 +191,8 @@ def structural_equal(lhs, rhs, map_free_vars=False):
191191
The left operand.
192192
193193
map_free_vars : bool
194-
Whether or not shall we map free vars that does
195-
not bound to any definitions as equal to each other.
194+
Whether free variables (i.e. variables without a definition site) should be mapped
195+
as equal to each other.
196196
197197
Return
198198
------
@@ -209,6 +209,36 @@ def structural_equal(lhs, rhs, map_free_vars=False):
209209
return bool(tvm.runtime._ffi_node_api.StructuralEqual(lhs, rhs, False, map_free_vars))
210210

211211

212+
def get_first_structural_mismatch(lhs, rhs, map_free_vars=False):
213+
"""Like structural_equal(), but returns the ObjectPaths of the first detected mismatch.
214+
215+
Parameters
216+
----------
217+
lhs : Object
218+
The left operand.
219+
220+
rhs : Object
221+
The left operand.
222+
223+
map_free_vars : bool
224+
Whether free variables (i.e. variables without a definition site) should be mapped
225+
as equal to each other.
226+
227+
Returns
228+
-------
229+
mismatch: Optional[Tuple[ObjectPath, ObjectPath]]
230+
`None` if `lhs` and `rhs` are structurally equal.
231+
Otherwise, a tuple of two ObjectPath objects that point to the first detected mismtach.
232+
"""
233+
lhs = tvm.runtime.convert(lhs)
234+
rhs = tvm.runtime.convert(rhs)
235+
mismatch = tvm.runtime._ffi_node_api.GetFirstStructuralMismatch(lhs, rhs, map_free_vars)
236+
if mismatch is None:
237+
return None
238+
else:
239+
return mismatch.lhs_path, mismatch.rhs_path
240+
241+
212242
def assert_structural_equal(lhs, rhs, map_free_vars=False):
213243
"""Assert lhs and rhs are structurally equal to each other.
214244

python/tvm/runtime/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
# class exposures
2020
from .packed_func import PackedFunc
2121
from .object import Object
22+
from .object_path import ObjectPath, ObjectPathPair
2223
from .object_generic import ObjectGeneric, ObjectTypes
2324
from .ndarray import NDArray, DataType, DataTypeCode, Device
2425
from .module import Module, num_threads

python/tvm/runtime/object_path.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
"MissingArrayElementPath",
3535
"MapValuePath",
3636
"MissingMapEntryPath",
37+
"ObjectPathPair",
3738
)
3839

3940

@@ -122,3 +123,18 @@ class MapValuePath(ObjectPath):
122123
@tvm._ffi.register_object("MissingMapEntryPath")
123124
class MissingMapEntryPath(ObjectPath):
124125
pass
126+
127+
128+
@tvm._ffi.register_object("ObjectPathPair")
129+
class ObjectPathPair(Object):
130+
"""
131+
Pair of ObjectPaths, one for each object being tested for structural equality.
132+
"""
133+
134+
@property
135+
def lhs_path(self) -> ObjectPath:
136+
return _ffi_node_api.ObjectPathPairLhsPath(self)
137+
138+
@property
139+
def rhs_path(self) -> ObjectPath:
140+
return _ffi_node_api.ObjectPathPairRhsPath(self)

src/node/reflection.cc

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,4 +281,48 @@ TVM_REGISTER_GLOBAL("node.NodeGetAttr").set_body(NodeGetAttr);
281281
TVM_REGISTER_GLOBAL("node.NodeListAttrNames").set_body(NodeListAttrNames);
282282

283283
TVM_REGISTER_GLOBAL("node.MakeNode").set_body(MakeNode);
284+
285+
namespace {
286+
// Attribute visitor class for finding the attribute key by its address
287+
class GetAttrKeyByAddressVisitor : public AttrVisitor {
288+
public:
289+
explicit GetAttrKeyByAddressVisitor(const void* attr_address)
290+
: attr_address_(attr_address), key_(nullptr) {}
291+
292+
void Visit(const char* key, double* value) final { DoVisit(key, value); }
293+
void Visit(const char* key, int64_t* value) final { DoVisit(key, value); }
294+
void Visit(const char* key, uint64_t* value) final { DoVisit(key, value); }
295+
void Visit(const char* key, int* value) final { DoVisit(key, value); }
296+
void Visit(const char* key, bool* value) final { DoVisit(key, value); }
297+
void Visit(const char* key, std::string* value) final { DoVisit(key, value); }
298+
void Visit(const char* key, void** value) final { DoVisit(key, value); }
299+
void Visit(const char* key, DataType* value) final { DoVisit(key, value); }
300+
void Visit(const char* key, runtime::NDArray* value) final { DoVisit(key, value); }
301+
void Visit(const char* key, runtime::ObjectRef* value) final { DoVisit(key, value); }
302+
303+
const char* GetKey() const { return key_; }
304+
305+
private:
306+
const void* attr_address_;
307+
const char* key_;
308+
309+
void DoVisit(const char* key, const void* candidate) {
310+
if (attr_address_ == candidate) {
311+
key_ = key;
312+
}
313+
}
314+
};
315+
} // anonymous namespace
316+
317+
Optional<String> GetAttrKeyByAddress(const Object* object, const void* attr_address) {
318+
GetAttrKeyByAddressVisitor visitor(attr_address);
319+
ReflectionVTable::Global()->VisitAttrs(const_cast<Object*>(object), &visitor);
320+
const char* key = visitor.GetKey();
321+
if (key == nullptr) {
322+
return NullOpt;
323+
} else {
324+
return String(key);
325+
}
326+
}
327+
284328
} // namespace tvm

0 commit comments

Comments
 (0)