Skip to content

Commit f7aeaf1

Browse files
authored
[TVMScript] Connect assert_structural_equal with new TVMScript printer (#13859)
This PR refactors the output of `assert_structural_equal`. Different from the directly printing mismatching nodes, in the old version, the improved one will print the whole scripts, with mismatching nodes underlined. And we print the `ObjectPath` to the mismatching nodes for further better debug. For example, we have following functions ```python @T.prim_func def func1(a: T.handle, b: T.handle): A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) @T.prim_func def func2(a: T.handle, b: T.handle): A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 256)) ``` the log of `assert_structural_equal(func1, func2)` will be like ```python ValueError: StructuralEqual check failed, caused by lhs at <root>.buffer_map[b].shape[1].value: # from tvm.script import tir as T @T.prim_func def main(a: T.handle, b: T.handle): A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) ^^^ T.evaluate(0) and rhs at <root>.buffer_map[b].shape[1].value: # from tvm.script import tir as T @T.prim_func def main(a: T.handle, b: T.handle): A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 256)) ^^^ T.evaluate(0) ``` instead of ```python ValueError: StructuralEqual check failed, caused by lhs: 128 and rhs: 256 ``` which is not readable sometimes.
1 parent 98008c2 commit f7aeaf1

File tree

15 files changed

+268
-48
lines changed

15 files changed

+268
-48
lines changed

include/tvm/node/script_printer.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,16 @@ class PrinterConfigNode : public Object {
5353
DataType float_dtype = DataType::Void();
5454
/*! \brief Whether or not to verbose print expressions. */
5555
bool verbose_expr = false;
56-
/* \brief Number of spaces used for indentation*/
56+
/*! \brief Number of spaces used for indentation*/
5757
int indent_spaces = 4;
58-
/* \brief Whether to print line numbers */
58+
/*! \brief Whether to print line numbers */
5959
bool print_line_numbers = false;
60-
/* \brief Number of context lines to print around the underlined text */
60+
/*! \brief Number of context lines to print around the underlined text */
6161
int num_context_lines = -1;
62-
/* \brief Object path to be underlined */
62+
/*! \brief Object path to be underlined */
6363
Optional<ObjectPath> path_to_underline = NullOpt;
64+
/*! \brief Whether to output with syntax sugar, set false for complete printing. */
65+
bool syntax_sugar = true;
6466

6567
void VisitAttrs(AttrVisitor* v) {
6668
v->Visit("ir_prefix", &ir_prefix);
@@ -72,6 +74,7 @@ class PrinterConfigNode : public Object {
7274
v->Visit("print_line_numbers", &print_line_numbers);
7375
v->Visit("num_context_lines", &num_context_lines);
7476
v->Visit("path_to_underline", &path_to_underline);
77+
v->Visit("syntax_sugar", &syntax_sugar);
7578
}
7679

7780
static constexpr const char* _type_key = "node.PrinterConfig";

include/tvm/node/structural_equal.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,13 @@ class SEqualReducer {
153153
*/
154154
virtual void DeferFail(const ObjectPathPair& mismatch_paths) = 0;
155155

156+
/*!
157+
* \brief Check if fail defferal is enabled.
158+
*
159+
* \return false if the fail deferral is not enabled, true otherwise.
160+
*/
161+
virtual bool IsFailDeferralEnabled() = 0;
162+
156163
/*!
157164
* \brief Lookup the graph node equal map for vars that are already mapped.
158165
*
@@ -331,12 +338,14 @@ class SEqualReducer {
331338
*/
332339
class SEqualHandlerDefault : public SEqualReducer::Handler {
333340
public:
334-
SEqualHandlerDefault(bool assert_mode, Optional<ObjectPathPair>* first_mismatch);
341+
SEqualHandlerDefault(bool assert_mode, Optional<ObjectPathPair>* first_mismatch,
342+
bool defer_fails);
335343
virtual ~SEqualHandlerDefault();
336344

337345
bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
338346
const Optional<ObjectPathPair>& current_paths) override;
339347
void DeferFail(const ObjectPathPair& mismatch_paths) override;
348+
bool IsFailDeferralEnabled() override;
340349
ObjectRef MapLhsToRhs(const ObjectRef& lhs) override;
341350
void MarkGraphNode() override;
342351

python/tvm/runtime/script_printer.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class PrinterConfig(Object):
3939
print_line_numbers: bool
4040
num_context_lines: int
4141
path_to_underline: Optional[ObjectPath]
42+
syntax_sugar: bool
4243

4344
def __init__(
4445
self,
@@ -54,6 +55,7 @@ def __init__(
5455
print_line_numbers: bool = False,
5556
num_context_lines: Optional[int] = None,
5657
path_to_underline: Optional[ObjectPath] = None,
58+
syntax_sugar: bool = True,
5759
) -> None:
5860
if num_context_lines is None:
5961
num_context_lines = -1
@@ -71,6 +73,7 @@ def __init__(
7173
"print_line_numbers": print_line_numbers,
7274
"num_context_lines": num_context_lines,
7375
"path_to_underline": path_to_underline,
76+
"syntax_sugar": syntax_sugar,
7477
},
7578
)
7679

@@ -96,6 +99,7 @@ def script(
9699
print_line_numbers: bool = False,
97100
num_context_lines: int = -1,
98101
path_to_underline: Optional[ObjectPath] = None,
102+
syntax_sugar: bool = True,
99103
) -> str:
100104
"""Print TVM IR into TVMScript text format
101105
@@ -123,6 +127,8 @@ def script(
123127
The number of lines of context to print before and after the line to underline.
124128
path_to_underline : Optional[ObjectPath] = None
125129
Object path to be underlined
130+
syntax_sugar: bool = True
131+
Whether to output with syntax sugar, set false for complete printing.
126132
127133
Returns
128134
-------
@@ -143,6 +149,7 @@ def script(
143149
print_line_numbers=print_line_numbers,
144150
num_context_lines=num_context_lines,
145151
path_to_underline=path_to_underline,
152+
syntax_sugar=syntax_sugar,
146153
),
147154
)
148155

@@ -162,6 +169,7 @@ def show(
162169
print_line_numbers: bool = False,
163170
num_context_lines: int = -1,
164171
path_to_underline: Optional[ObjectPath] = None,
172+
syntax_sugar: bool = True,
165173
) -> None:
166174
"""A sugar for print highlighted TVM script.
167175
@@ -194,6 +202,8 @@ def show(
194202
The number of lines of context to print before and after the line to underline.
195203
path_to_underline : Optional[ObjectPath] = None
196204
Object path to be underlined
205+
syntax_sugar: bool = True
206+
Whether to output with syntax sugar, set false for complete printing.
197207
"""
198208
from tvm.script.highlight import ( # pylint: disable=import-outside-toplevel
199209
cprint,
@@ -212,6 +222,7 @@ def show(
212222
print_line_numbers=print_line_numbers,
213223
num_context_lines=num_context_lines,
214224
path_to_underline=path_to_underline,
225+
syntax_sugar=syntax_sugar,
215226
),
216227
style=style,
217228
black_format=black_format,

src/meta_schedule/module_equality.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class ModuleEqualityStructural : public ModuleEquality {
3838

3939
class SEqualHandlerIgnoreNDArray : public SEqualHandlerDefault {
4040
public:
41-
SEqualHandlerIgnoreNDArray() : SEqualHandlerDefault(false, nullptr) {}
41+
SEqualHandlerIgnoreNDArray() : SEqualHandlerDefault(false, nullptr, false) {}
4242

4343
protected:
4444
bool DispatchSEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,

src/node/script_printer.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ TVMScriptPrinter::FType& TVMScriptPrinter::vtable() {
2929
}
3030

3131
std::string TVMScriptPrinter::Script(const ObjectRef& node, const Optional<PrinterConfig>& cfg) {
32+
if (!TVMScriptPrinter::vtable().can_dispatch(node)) {
33+
return AsLegacyRepr(node);
34+
}
3235
return TVMScriptPrinter::vtable()(node, cfg.value_or(PrinterConfig()));
3336
}
3437

@@ -67,6 +70,9 @@ PrinterConfig::PrinterConfig(Map<String, ObjectRef> config_dict) {
6770
if (auto v = config_dict.Get("path_to_underline")) {
6871
n->path_to_underline = Downcast<ObjectPath>(v);
6972
}
73+
if (auto v = config_dict.Get("syntax_sugar")) {
74+
n->syntax_sugar = Downcast<IntImm>(v)->value;
75+
}
7076
this->data_ = std::move(n);
7177
}
7278

src/node/structural_equal.cc

Lines changed: 61 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -202,8 +202,12 @@ bool SEqualReducer::ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs,
202202
*/
203203
class SEqualHandlerDefault::Impl {
204204
public:
205-
Impl(SEqualHandlerDefault* parent, bool assert_mode, Optional<ObjectPathPair>* first_mismatch)
206-
: parent_(parent), assert_mode_(assert_mode), first_mismatch_(first_mismatch) {}
205+
Impl(SEqualHandlerDefault* parent, bool assert_mode, Optional<ObjectPathPair>* first_mismatch,
206+
bool defer_fails)
207+
: parent_(parent),
208+
assert_mode_(assert_mode),
209+
first_mismatch_(first_mismatch),
210+
defer_fails_(defer_fails) {}
207211

208212
bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
209213
const Optional<ObjectPathPair>& current_paths) {
@@ -245,6 +249,8 @@ class SEqualHandlerDefault::Impl {
245249
pending_tasks_.emplace_back(Task::ForceFailTag{}, mismatch_paths);
246250
}
247251

252+
bool IsFailDeferralEnabled() { return defer_fails_; }
253+
248254
void MarkGraphNode() {
249255
// need to push to pending tasks in this case
250256
ICHECK(!allow_push_to_stack_ && !task_stack_.empty());
@@ -264,6 +270,8 @@ class SEqualHandlerDefault::Impl {
264270
pending_tasks_.clear();
265271
equal_map_lhs_.clear();
266272
equal_map_rhs_.clear();
273+
root_lhs_ = lhs;
274+
root_rhs_ = rhs;
267275

268276
Optional<ObjectPathPair> current_paths;
269277
if (IsPathTracingEnabled()) {
@@ -313,10 +321,38 @@ class SEqualHandlerDefault::Impl {
313321
*first_mismatch_ = current_paths;
314322
}
315323
if (assert_mode_ && !result) {
316-
LOG(FATAL) << "ValueError: StructuralEqual check failed, caused by lhs:" << std::endl
317-
<< lhs << std::endl
318-
<< "and rhs:" << std::endl
319-
<< rhs;
324+
std::ostringstream oss;
325+
oss << "ValueError: StructuralEqual check failed, caused by lhs";
326+
if (first_mismatch_->defined()) {
327+
oss << " at " << first_mismatch_->value()->lhs_path;
328+
if (root_lhs_.defined()) {
329+
Map<String, ObjectRef> dict = {{"path_to_underline", first_mismatch_->value()->lhs_path},
330+
{"syntax_sugar", Bool(false)}};
331+
PrinterConfig cfg(dict);
332+
// The TVMScriptPrinter::Script will fallback to Repr printer,
333+
// if the root node to print is not supported yet,
334+
// e.g. Relay nodes, ArrayNode, MapNode, etc.
335+
oss << ":" << std::endl << TVMScriptPrinter::Script(root_lhs_.value(), cfg);
336+
}
337+
} else {
338+
oss << ":" << std::endl << lhs;
339+
}
340+
oss << std::endl << "and rhs";
341+
if (first_mismatch_->defined()) {
342+
oss << " at " << first_mismatch_->value()->rhs_path;
343+
if (root_rhs_.defined()) {
344+
Map<String, ObjectRef> dict = {{"path_to_underline", first_mismatch_->value()->rhs_path},
345+
{"syntax_sugar", Bool(false)}};
346+
PrinterConfig cfg(dict);
347+
// The TVMScriptPrinter::Script will fallback to Repr printer,
348+
// if the root node to print is not supported yet,
349+
// e.g. Relay nodes, ArrayNode, MapNode, etc.
350+
oss << ":" << std::endl << TVMScriptPrinter::Script(root_rhs_.value(), cfg);
351+
}
352+
} else {
353+
oss << ":" << std::endl << rhs;
354+
}
355+
LOG(FATAL) << oss.str();
320356
}
321357
return result;
322358
}
@@ -419,19 +455,27 @@ class SEqualHandlerDefault::Impl {
419455
bool allow_push_to_stack_{true};
420456
// If in assert mode, must return true, and will throw error otherwise.
421457
bool assert_mode_{false};
422-
// Location to store the paths to the first detected mismatch, or nullptr to disable path tracing.
458+
// Location to store the paths to the first detected mismatch, or nullptr to disable path
459+
// tracing.
423460
Optional<ObjectPathPair>* first_mismatch_;
424461
// reflection vtable
425462
ReflectionVTable* vtable_ = ReflectionVTable::Global();
426463
// map from lhs to rhs
427464
std::unordered_map<ObjectRef, ObjectRef, ObjectPtrHash, ObjectPtrEqual> equal_map_lhs_;
428465
// map from rhs to lhs
429466
std::unordered_map<ObjectRef, ObjectRef, ObjectPtrHash, ObjectPtrEqual> equal_map_rhs_;
467+
// root lhs for result printing
468+
Optional<ObjectRef> root_lhs_;
469+
// root rhs for result printing
470+
Optional<ObjectRef> root_rhs_;
471+
// whether to defer fails
472+
bool defer_fails_;
430473
};
431474

432475
SEqualHandlerDefault::SEqualHandlerDefault(bool assert_mode,
433-
Optional<ObjectPathPair>* first_mismatch) {
434-
impl = new Impl(this, assert_mode, first_mismatch);
476+
Optional<ObjectPathPair>* first_mismatch,
477+
bool defer_fails) {
478+
impl = new Impl(this, assert_mode, first_mismatch, defer_fails);
435479
}
436480

437481
SEqualHandlerDefault::~SEqualHandlerDefault() { delete impl; }
@@ -446,6 +490,8 @@ void SEqualHandlerDefault::DeferFail(const ObjectPathPair& mismatch_paths) {
446490
impl->DeferFail(mismatch_paths);
447491
}
448492

493+
bool SEqualHandlerDefault::IsFailDeferralEnabled() { return impl->IsFailDeferralEnabled(); }
494+
449495
ObjectRef SEqualHandlerDefault::MapLhsToRhs(const ObjectRef& lhs) { return impl->MapLhsToRhs(lhs); }
450496

451497
void SEqualHandlerDefault::MarkGraphNode() { impl->MarkGraphNode(); }
@@ -463,19 +509,22 @@ bool SEqualHandlerDefault::DispatchSEqualReduce(const ObjectRef& lhs, const Obje
463509
TVM_REGISTER_GLOBAL("node.StructuralEqual")
464510
.set_body_typed([](const ObjectRef& lhs, const ObjectRef& rhs, bool assert_mode,
465511
bool map_free_vars) {
466-
return SEqualHandlerDefault(assert_mode, nullptr).Equal(lhs, rhs, map_free_vars);
512+
Optional<ObjectPathPair> first_mismatch;
513+
return SEqualHandlerDefault(assert_mode, &first_mismatch, false)
514+
.Equal(lhs, rhs, map_free_vars);
467515
});
468516

469517
TVM_REGISTER_GLOBAL("node.GetFirstStructuralMismatch")
470518
.set_body_typed([](const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) {
471519
Optional<ObjectPathPair> first_mismatch;
472-
bool equal = SEqualHandlerDefault(false, &first_mismatch).Equal(lhs, rhs, map_free_vars);
520+
bool equal =
521+
SEqualHandlerDefault(false, &first_mismatch, true).Equal(lhs, rhs, map_free_vars);
473522
ICHECK(equal == !first_mismatch.defined());
474523
return first_mismatch;
475524
});
476525

477526
bool StructuralEqual::operator()(const ObjectRef& lhs, const ObjectRef& rhs) const {
478-
return SEqualHandlerDefault(false, nullptr).Equal(lhs, rhs, false);
527+
return SEqualHandlerDefault(false, nullptr, false).Equal(lhs, rhs, false);
479528
}
480529

481530
bool NDArrayEqual(const runtime::NDArray::Container* lhs, const runtime::NDArray::Container* rhs,

src/node/structural_hash.cc

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -467,16 +467,18 @@ struct ArrayNodeTrait {
467467
// (2) a b c d e g h i j k l m
468468
// ^
469469
// error here
470-
if (lhs->size() > min_size) {
471-
equal->DeferFail({array_paths->lhs_path->ArrayIndex(min_size),
472-
array_paths->rhs_path->MissingArrayElement(min_size)});
473-
} else {
474-
equal->DeferFail({array_paths->lhs_path->MissingArrayElement(min_size),
475-
array_paths->rhs_path->ArrayIndex(min_size)});
470+
if (equal->IsFailDeferralEnabled()) {
471+
if (lhs->size() > min_size) {
472+
equal->DeferFail({array_paths->lhs_path->ArrayIndex(min_size),
473+
array_paths->rhs_path->MissingArrayElement(min_size)});
474+
} else {
475+
equal->DeferFail({array_paths->lhs_path->MissingArrayElement(min_size),
476+
array_paths->rhs_path->ArrayIndex(min_size)});
477+
}
478+
// Can return `true` pretending that everything is good since we have deferred the failure.
479+
return true;
476480
}
477-
478-
// Can return `true` pretending that everything is good since we have deferred the failure.
479-
return true;
481+
return false;
480482
}
481483
};
482484
TVM_REGISTER_REFLECTION_VTABLE(ArrayNode, ArrayNodeTrait)

src/script/printer/doc_printer/python_doc_printer.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -510,7 +510,9 @@ void PythonDocPrinter::PrintTypedDoc(const SliceDoc& doc) {
510510
void PythonDocPrinter::PrintTypedDoc(const StmtBlockDoc& doc) {
511511
for (const StmtDoc& stmt : doc->stmts) {
512512
PrintDoc(stmt);
513-
NewLine();
513+
if (stmt != doc->stmts.back()) {
514+
NewLine();
515+
}
514516
}
515517
}
516518

src/script/printer/tir/block.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, //
4444

4545
std::vector<int> remap_vars_indices;
4646
auto add_remapped_iter_var = [&](int i) -> bool {
47-
if (realize) {
47+
if (realize && d->cfg->syntax_sugar) {
4848
tir::ExprDeepEqual expr_equal;
4949
tir::IterVar iter_var = block->iter_vars[i];
5050
PrimExpr value = realize->iter_values[i];

src/script/printer/tir/for_loop.cc

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,17 +32,19 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
3232
return grid_loop_vars.count(v);
3333
});
3434
};
35-
for (const tir::ForNode* l = loop.get(); l != nullptr; l = l->body.as<tir::ForNode>()) {
36-
ICHECK(l->loop_var->dtype == l->min->dtype);
37-
ICHECK(l->loop_var->dtype == l->extent->dtype);
38-
if (l->kind != tir::ForKind::kSerial || //
39-
!tir::is_zero(l->min) || //
40-
!l->annotations.empty() || //
41-
f_var_dep(l->extent)) {
42-
break;
35+
if (d->cfg->syntax_sugar) {
36+
for (const tir::ForNode* l = loop.get(); l != nullptr; l = l->body.as<tir::ForNode>()) {
37+
ICHECK(l->loop_var->dtype == l->min->dtype);
38+
ICHECK(l->loop_var->dtype == l->extent->dtype);
39+
if (l->kind != tir::ForKind::kSerial || //
40+
!tir::is_zero(l->min) || //
41+
!l->annotations.empty() || //
42+
f_var_dep(l->extent)) {
43+
break;
44+
}
45+
grid.push_back(l);
46+
grid_loop_vars.insert(l->loop_var.get());
4347
}
44-
grid.push_back(l);
45-
grid_loop_vars.insert(l->loop_var.get());
4648
}
4749
With<TIRFrame> f(d, loop);
4850
// Step 2. Construct `T.grid`
@@ -114,7 +116,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
114116
kwargs_values.push_back(annotations.value());
115117
}
116118
ExprDoc rhs = prefix->Call(args, kwargs_keys, kwargs_values);
117-
AsDocBody(loop->body, loop_p, (*f).get(), d);
119+
AsDocBody(loop->body, loop_p->Attr("body"), (*f).get(), d);
118120
return ForDoc(lhs, rhs, (*f)->stmts);
119121
});
120122

0 commit comments

Comments
 (0)