Skip to content

Commit 3b3443b

Browse files
authored
[TIR][Schedule][UX] Beautify TIR Trace Printing (#12507)
Following #12197, this PR introduces `Schedule.show()` which convenience the user experience in the following two aspects: - Python syntax highlighting - Outputs a schedule function instead of standalone instructions so that it's easier to follow. To demonstrate this change: - Before `Schedule.show()` is introduced: <img width="555" alt="image" src="https://user-images.githubusercontent.com/22515877/185713487-03722566-1df7-45c7-a034-c1460d399681.png"> - After this change: <img width="583" alt="image" src="https://user-images.githubusercontent.com/22515877/185713564-c54f3a9d-cd52-4709-a8b8-d8a61361e611.png">
1 parent eb31123 commit 3b3443b

File tree

6 files changed

+118
-56
lines changed

6 files changed

+118
-56
lines changed

python/tvm/meta_schedule/testing/space_generation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def check_trace(spaces: List[Schedule], expected: List[List[str]]):
2929
for space in spaces:
3030
trace = Trace(space.trace.insts, {})
3131
trace = trace.simplified(remove_postproc=True)
32-
str_trace = "\n".join(str(trace).strip().splitlines())
32+
str_trace = "\n".join(t[2:] for t in str(trace).strip().splitlines()[2:] if t != " pass")
3333
actual_traces.add(str_trace)
3434
assert str_trace in expected_traces, "\n" + str_trace
3535
assert len(expected_traces) == len(actual_traces)

python/tvm/script/highlight.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,20 @@
1717
"""Highlight printed TVM script.
1818
"""
1919

20-
from typing import Union, Optional
21-
import warnings
2220
import sys
21+
import warnings
22+
from typing import Optional, Union
2323

2424
from tvm.ir import IRModule
2525
from tvm.tir import PrimFunc
2626

2727

28-
def cprint(printable: Union[IRModule, PrimFunc], style: Optional[str] = None) -> None:
28+
def cprint(printable: Union[IRModule, PrimFunc, str], style: Optional[str] = None) -> None:
2929
"""
3030
Print highlighted TVM script string with Pygments
3131
Parameters
3232
----------
33-
printable : Union[IRModule, PrimFunc]
33+
printable : Union[IRModule, PrimFunc, str]
3434
The TVM script to be printed
3535
style : str, optional
3636
Printing style, auto-detected if None.
@@ -44,16 +44,17 @@ def cprint(printable: Union[IRModule, PrimFunc], style: Optional[str] = None) ->
4444
installing the Pygment library. Other Pygment styles can be found in
4545
https://pygments.org/styles/
4646
"""
47-
47+
if isinstance(printable, (IRModule, PrimFunc)):
48+
printable = printable.script()
4849
try:
4950
# pylint: disable=import-outside-toplevel
5051
import pygments
52+
from packaging import version
5153
from pygments import highlight
54+
from pygments.formatters import HtmlFormatter, Terminal256Formatter
5255
from pygments.lexers.python import Python3Lexer
53-
from pygments.formatters import Terminal256Formatter, HtmlFormatter
5456
from pygments.style import Style
55-
from pygments.token import Keyword, Name, Comment, String, Number, Operator
56-
from packaging import version
57+
from pygments.token import Comment, Keyword, Name, Number, Operator, String
5758

5859
if version.parse(pygments.__version__) < version.parse("2.4.0"):
5960
raise ImportError("Required Pygments version >= 2.4.0 but got " + pygments.__version__)
@@ -68,7 +69,7 @@ def cprint(printable: Union[IRModule, PrimFunc], style: Optional[str] = None) ->
6869
+ install_cmd,
6970
category=UserWarning,
7071
)
71-
print(printable.script())
72+
print(printable)
7273
else:
7374

7475
class JupyterLight(Style):
@@ -136,11 +137,14 @@ class AnsiTerminalDefault(Style):
136137
style = AnsiTerminalDefault
137138

138139
if is_in_notebook: # print with HTML display
139-
from IPython.display import display, HTML # pylint: disable=import-outside-toplevel
140+
from IPython.display import ( # pylint: disable=import-outside-toplevel
141+
HTML,
142+
display,
143+
)
140144

141145
formatter = HtmlFormatter(style=JupyterLight)
142146
formatter.noclasses = True # inline styles
143-
html = highlight(printable.script(), Python3Lexer(), formatter)
147+
html = highlight(printable, Python3Lexer(), formatter)
144148
display(HTML(html))
145149
else:
146-
print(highlight(printable.script(), Python3Lexer(), Terminal256Formatter(style=style)))
150+
print(highlight(printable, Python3Lexer(), Terminal256Formatter(style=style)))

python/tvm/tir/schedule/trace.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,3 +258,17 @@ def apply_json_to_schedule(json_obj: JSON_TYPE, sch: "Schedule") -> None:
258258
The TensorIR schedule
259259
"""
260260
_ffi_api.TraceApplyJSONToSchedule(json_obj, sch) # type: ignore # pylint: disable=no-member
261+
262+
def show(self, style: Optional[str] = None) -> None:
263+
"""A sugar for print highlighted trace.
264+
265+
Parameters
266+
----------
267+
style : str, optional
268+
Pygments styles extended by "light" (default) and "dark", by default "light"
269+
"""
270+
from tvm.script.highlight import ( # pylint: disable=import-outside-toplevel
271+
cprint,
272+
)
273+
274+
cprint(str(self), style=style)

src/tir/schedule/trace.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -476,16 +476,22 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
476476
.set_dispatch<TraceNode>([](const ObjectRef& obj, ReprPrinter* p) {
477477
const auto* self = obj.as<TraceNode>();
478478
ICHECK_NOTNULL(self);
479+
p->stream << "# from tvm import tir\n";
480+
p->stream << "def apply_trace(sch: tir.Schedule) -> None:\n";
479481
Array<String> repr = self->AsPython(/*remove_postproc=*/false);
480482
bool is_first = true;
481483
for (const String& line : repr) {
482484
if (is_first) {
483485
is_first = false;
484486
} else {
485-
p->stream << std::endl;
487+
p->stream << '\n';
486488
}
487-
p->stream << line;
489+
p->stream << " " << line;
488490
}
491+
if (is_first) {
492+
p->stream << " pass";
493+
}
494+
p->stream << std::flush;
489495
});
490496

491497
/**************** Instruction Registration ****************/

tests/python/unittest/test_meta_schedule_post_order_apply.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -322,18 +322,22 @@ def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
322322
def correct_trace(a, b, c, d):
323323
return "\n".join(
324324
[
325-
'b0 = sch.get_block(name="A", func_name="main")',
326-
'b1 = sch.get_block(name="B", func_name="main")',
327-
'b2 = sch.get_block(name="C", func_name="main")',
328-
"sch.compute_inline(block=b1)",
329-
"l3, l4 = sch.get_loops(block=b2)",
330-
"l5, l6 = sch.split(loop=l3, factors=" + str(a) + ", preserve_unit_iters=True)",
331-
"l7, l8 = sch.split(loop=l4, factors=" + str(b) + ", preserve_unit_iters=True)",
332-
"sch.reorder(l5, l7, l6, l8)",
333-
"l9, l10 = sch.get_loops(block=b0)",
334-
"l11, l12 = sch.split(loop=l9, factors=" + str(c) + ", preserve_unit_iters=True)",
335-
"l13, l14 = sch.split(loop=l10, factors=" + str(d) + ", preserve_unit_iters=True)",
336-
"sch.reorder(l11, l13, l12, l14)",
325+
"# from tvm import tir",
326+
"def apply_trace(sch: tir.Schedule) -> None:",
327+
' b0 = sch.get_block(name="A", func_name="main")',
328+
' b1 = sch.get_block(name="B", func_name="main")',
329+
' b2 = sch.get_block(name="C", func_name="main")',
330+
" sch.compute_inline(block=b1)",
331+
" l3, l4 = sch.get_loops(block=b2)",
332+
" l5, l6 = sch.split(loop=l3, factors=" + str(a) + ", preserve_unit_iters=True)",
333+
" l7, l8 = sch.split(loop=l4, factors=" + str(b) + ", preserve_unit_iters=True)",
334+
" sch.reorder(l5, l7, l6, l8)",
335+
" l9, l10 = sch.get_loops(block=b0)",
336+
" l11, l12 = sch.split(loop=l9, factors=" + str(c) + ", preserve_unit_iters=True)",
337+
" l13, l14 = sch.split(loop=l10, factors="
338+
+ str(d)
339+
+ ", preserve_unit_iters=True)",
340+
" sch.reorder(l11, l13, l12, l14)",
337341
]
338342
)
339343

tests/python/unittest/test_tir_schedule_trace.py

Lines changed: 63 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -163,8 +163,10 @@ def test_trace_construct_1():
163163
trace = _make_trace_1(BlockRV(), LoopRV(), LoopRV())
164164
assert str(trace) == "\n".join(
165165
(
166-
'b0 = sch.get_block(name="block", func_name="main")',
167-
"l1, l2 = sch.get_loops(block=b0)",
166+
"# from tvm import tir",
167+
"def apply_trace(sch: tir.Schedule) -> None:",
168+
' b0 = sch.get_block(name="block", func_name="main")',
169+
" l1, l2 = sch.get_loops(block=b0)",
168170
)
169171
)
170172
assert len(trace.insts) == 2
@@ -182,9 +184,11 @@ def test_trace_construct_append_1():
182184
trace.append(inst=_make_get_block("block2", BlockRV()))
183185
assert str(trace) == "\n".join(
184186
(
185-
'b0 = sch.get_block(name="block", func_name="main")',
186-
"l1, l2 = sch.get_loops(block=b0)",
187-
'b3 = sch.get_block(name="block2", func_name="main")',
187+
"# from tvm import tir",
188+
"def apply_trace(sch: tir.Schedule) -> None:",
189+
' b0 = sch.get_block(name="block", func_name="main")',
190+
" l1, l2 = sch.get_loops(block=b0)",
191+
' b3 = sch.get_block(name="block2", func_name="main")',
188192
)
189193
)
190194

@@ -193,14 +197,32 @@ def test_trace_construct_pop_1():
193197
trace = _make_trace_1(BlockRV(), LoopRV(), LoopRV())
194198
last_inst = trace.insts[-1]
195199
assert trace.pop().same_as(last_inst)
196-
assert str(trace) == 'b0 = sch.get_block(name="block", func_name="main")'
200+
assert str(trace) == "\n".join(
201+
(
202+
"# from tvm import tir",
203+
"def apply_trace(sch: tir.Schedule) -> None:",
204+
' b0 = sch.get_block(name="block", func_name="main")',
205+
)
206+
)
197207

198208

199209
def test_trace_construct_pop_2():
200210
trace = Trace([], {})
201-
assert str(trace) == ""
211+
assert str(trace) == "\n".join(
212+
(
213+
"# from tvm import tir",
214+
"def apply_trace(sch: tir.Schedule) -> None:",
215+
" pass",
216+
)
217+
)
202218
assert trace.pop() is None
203-
assert str(trace) == ""
219+
assert str(trace) == "\n".join(
220+
(
221+
"# from tvm import tir",
222+
"def apply_trace(sch: tir.Schedule) -> None:",
223+
" pass",
224+
)
225+
)
204226

205227

206228
def test_trace_apply_to_schedule():
@@ -226,18 +248,22 @@ def test_trace_simplified_1():
226248
trace = _make_trace_3(BlockRV(), BlockRV(), add_postproc=True)
227249
assert str(trace) == "\n".join(
228250
(
229-
'b0 = sch.get_block(name="B", func_name="main")',
230-
"sch.compute_inline(block=b0)",
231-
'b1 = sch.get_block(name="C", func_name="main")',
232-
"sch.enter_postproc()",
233-
"sch.compute_inline(block=b1)",
251+
"# from tvm import tir",
252+
"def apply_trace(sch: tir.Schedule) -> None:",
253+
' b0 = sch.get_block(name="B", func_name="main")',
254+
" sch.compute_inline(block=b0)",
255+
' b1 = sch.get_block(name="C", func_name="main")',
256+
" sch.enter_postproc()",
257+
" sch.compute_inline(block=b1)",
234258
)
235259
)
236260
trace = trace.simplified(remove_postproc=True)
237261
assert str(trace) == "\n".join(
238262
(
239-
'b0 = sch.get_block(name="B", func_name="main")',
240-
"sch.compute_inline(block=b0)",
263+
"# from tvm import tir",
264+
"def apply_trace(sch: tir.Schedule) -> None:",
265+
' b0 = sch.get_block(name="B", func_name="main")',
266+
" sch.compute_inline(block=b0)",
241267
)
242268
)
243269

@@ -246,21 +272,26 @@ def test_trace_simplified_2():
246272
trace = _make_trace_3(BlockRV(), BlockRV(), add_postproc=True)
247273
assert str(trace) == "\n".join(
248274
(
249-
'b0 = sch.get_block(name="B", func_name="main")',
250-
"sch.compute_inline(block=b0)",
251-
'b1 = sch.get_block(name="C", func_name="main")',
252-
"sch.enter_postproc()",
253-
"sch.compute_inline(block=b1)",
275+
"# from tvm import tir",
276+
"def apply_trace(sch: tir.Schedule) -> None:",
277+
' b0 = sch.get_block(name="B", func_name="main")',
278+
" sch.compute_inline(block=b0)",
279+
' b1 = sch.get_block(name="C", func_name="main")',
280+
" sch.enter_postproc()",
281+
" sch.compute_inline(block=b1)",
254282
)
255283
)
256284
trace = trace.simplified(remove_postproc=False)
285+
print(trace.show())
257286
assert str(trace) == "\n".join(
258287
(
259-
'b0 = sch.get_block(name="B", func_name="main")',
260-
"sch.compute_inline(block=b0)",
261-
'b1 = sch.get_block(name="C", func_name="main")',
262-
"sch.enter_postproc()",
263-
"sch.compute_inline(block=b1)",
288+
"# from tvm import tir",
289+
"def apply_trace(sch: tir.Schedule) -> None:",
290+
' b0 = sch.get_block(name="B", func_name="main")',
291+
" sch.compute_inline(block=b0)",
292+
' b1 = sch.get_block(name="C", func_name="main")',
293+
" sch.enter_postproc()",
294+
" sch.compute_inline(block=b1)",
264295
)
265296
)
266297

@@ -269,9 +300,11 @@ def test_trace_simplified_3():
269300
trace = _make_trace_4(BlockRV(), LoopRV(), LoopRV(), LoopRV()).simplified(remove_postproc=False)
270301
assert str(trace) == "\n".join(
271302
(
272-
'b0 = sch.get_block(name="B", func_name="main")',
273-
"l1, = sch.get_loops(block=b0)",
274-
"l2, l3 = sch.split(loop=l1, factors=[None, 32], preserve_unit_iters=True)",
303+
"# from tvm import tir",
304+
"def apply_trace(sch: tir.Schedule) -> None:",
305+
' b0 = sch.get_block(name="B", func_name="main")',
306+
" l1, = sch.get_loops(block=b0)",
307+
" l2, l3 = sch.split(loop=l1, factors=[None, 32], preserve_unit_iters=True)",
275308
)
276309
)
277310

@@ -335,4 +368,5 @@ def test_apply_annotation_from_json():
335368

336369

337370
if __name__ == "__main__":
338-
tvm.testing.main()
371+
test_trace_simplified_2()
372+
# tvm.testing.main()

0 commit comments

Comments
 (0)