Skip to content

Commit 556ec6b

Browse files
author
Joey Tsai
committed
[SpanFillingCommonAPI]
- Change based on comment - Discard the change of pretty-print - Add document to set_span
1 parent 8883a1a commit 556ec6b

File tree

3 files changed

+54
-36
lines changed

3 files changed

+54
-36
lines changed

python/tvm/relay/frontend/common.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -313,10 +313,9 @@ def new_const(self, value, shape=None, dtype="float32", source_name=None):
313313
shape = value.shape
314314
self.const_ctr += 1
315315
self.params[name] = value
316-
tmp_var = _expr.var(name_hint=name, shape=shape, dtype=dtype)
316+
self.exprs[name] = _expr.var(name_hint=name, shape=shape, dtype=dtype)
317317
if source_name:
318-
tmp_var = set_span(tmp_var, source_name)
319-
self.exprs[name] = tmp_var
318+
self.exprs[name] = set_span(self.exprs[name], source_name)
320319
return self.exprs[name]
321320

322321
def get_expr(self, name):
@@ -1130,7 +1129,39 @@ def _should_fill_span():
11301129

11311130

11321131
def set_span(sym, span):
1133-
"""Set up the sapn of relay expression(s) while converting OP"""
1132+
"""
1133+
Recursively tag the span to the symbol. Stop when it encounters a span-tagged expr. Disabled
1134+
when setting the environment variable "TVM_SPANFILLING" as 0.
1135+
1136+
Parameters
1137+
----------
1138+
sym :
1139+
A symbol is generated from the conversion of a frontend operator. Raise an error when the
1140+
type of the symbol is not supported.
1141+
1142+
span : String, Span, or bytes
1143+
The source information of the corresponding symbol.
1144+
1145+
Returns
1146+
-------
1147+
result :
1148+
The symbol tagged with span.
1149+
1150+
Examples
1151+
--------
1152+
.. code-block:: python
1153+
1154+
x = set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var")
1155+
w = relay.const(np.ones([64, 64, 3, 3]), dtype="int64")
1156+
y = set_span(
1157+
relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1)), "conv2d"
1158+
)
1159+
print(relay.Function([x], y))
1160+
1161+
#fn (%x: Tensor[(1, 64, 56, 56), float32] /* span=x_var:0:0 */) {
1162+
# nn.conv2d(%x, meta[relay.Constant][0] /* span=conv2d:0:0 */, ...) /* span=conv2d:0:0 */
1163+
#}
1164+
"""
11341165

11351166
if _should_fill_span():
11361167
return _SpanFiller(span).fill(sym)

src/printer/relay_text_printer.cc

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -65,20 +65,22 @@ Doc RelayTextPrinter::PrintOptionalInfo(const Expr& expr) {
6565
if (annotate_ == nullptr) {
6666
if ((expr.as<ConstantNode>() || expr.as<CallNode>() || expr.as<VarNode>() ||
6767
expr.as<FunctionNode>() || expr.as<TupleNode>() || expr.as<TupleGetItemNode>()) &&
68-
expr->checked_type_.defined()) {
69-
doc << " /* ty=" << Print(expr->checked_type()) << " */";
68+
(expr->checked_type_.defined() || expr->span.defined())) {
69+
doc << " /*";
70+
if (expr->checked_type_.defined()) {
71+
doc << " ty=" << Print(expr->checked_type());
72+
}
73+
if (expr->span.defined()) {
74+
doc << " span=" << PrintSpan(expr->span);
75+
}
76+
doc << " */";
7077
}
7178
} else {
7279
std::string annotated_expr = annotate_(expr);
7380
if (annotated_expr != "") {
7481
doc << annotated_expr;
7582
}
7683
}
77-
78-
if (expr->span.defined()) {
79-
doc << " /* si=" << Print(expr->span) << " */";
80-
}
81-
8284
return doc;
8385
}
8486

@@ -130,10 +132,6 @@ Doc RelayTextPrinter::Print(const ObjectRef& node, bool meta, bool try_inline) {
130132
return PrintPattern(Downcast<Pattern>(node), meta);
131133
} else if (node.as<IRModuleNode>()) {
132134
return PrintMod(Downcast<IRModule>(node));
133-
} else if (node.as<SpanNode>()) {
134-
std::ostringstream os;
135-
os << Downcast<Span>(node);
136-
return Doc::RawText(os.str());
137135
} else {
138136
// default module.
139137
std::ostringstream os;
@@ -964,6 +962,14 @@ Doc RelayTextPrinter::PrintMapAsAttributeValue(const Map<ObjectRef, ObjectRef>&
964962
return doc;
965963
}
966964

965+
Doc RelayTextPrinter::PrintSpan(const Span& span) {
966+
Doc doc;
967+
const auto* span_node = span.as<SpanNode>();
968+
ICHECK(span_node);
969+
doc << span_node->source_name->name << ":" << span_node->line << ":" << span_node->column;
970+
return doc;
971+
}
972+
967973
TVM_REGISTER_GLOBAL("ir.TextPrinter").set_body_typed([](ObjectRef node) {
968974
auto text = AsText(node, false, nullptr);
969975
return text;

src/printer/text_printer.h

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ class RelayTextPrinter : public ExprFunctor<Doc(const Expr&)>,
113113
*/
114114
Doc PrintMapAsAttributeValue(const Map<ObjectRef, ObjectRef>& map);
115115

116+
Doc PrintSpan(const Span& span);
117+
116118
Doc Print(const ObjectRef& node, bool meta = false, bool try_inline = false);
117119

118120
Doc TempVar(int n);
@@ -470,25 +472,4 @@ class TextPrinter {
470472
};
471473
} // namespace tvm
472474

473-
namespace tvm {
474-
namespace runtime {
475-
476-
inline std::ostream& operator<<(std::ostream& os, const SourceName& source_name) { // NOLINT(*)
477-
ICHECK(source_name->name.defined());
478-
os << source_name->name;
479-
return os;
480-
}
481-
482-
inline std::ostream& operator<<(std::ostream& os, const Span& span) { // NOLINT(*)
483-
if (span.defined()) {
484-
os << span->source_name;
485-
} else {
486-
os << "nullptr";
487-
}
488-
return os;
489-
}
490-
491-
} // namespace runtime
492-
} // namespace tvm
493-
494475
#endif // TVM_PRINTER_TEXT_PRINTER_H_

0 commit comments

Comments
 (0)