Skip to content

Commit 657ebbb

Browse files
[TVMScript] Support continue and break in tvmscript (#17804)
* support continue and break in tvmscript * fix black format * fix pylint issue * Update tests/python/tvmscript/test_tvmscript_syntax_sugar.py Co-authored-by: Copilot <[email protected]> * add printer/parser test, fix lint * Fit to latest ffi update * Skip i386 numpy-related test * Introduce AnnotateIrregularLoop before any lowering loop expansions. --------- Co-authored-by: Copilot <[email protected]>
1 parent 53356be commit 657ebbb

File tree

24 files changed

+657
-15
lines changed

24 files changed

+657
-15
lines changed

include/tvm/tir/builtin.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,14 @@ TVM_DLL const Op& ret();
4949
* \brief Return from a GPU thread.
5050
*/
5151
TVM_DLL const Op& thread_return();
52+
/*!
53+
* \brief Loop continue.
54+
*/
55+
TVM_DLL const Op& continue_loop();
56+
/*!
57+
* \brief Loop break.
58+
*/
59+
TVM_DLL const Op& break_loop();
5260
/*!
5361
* \brief Reinterpret the value using the target type.
5462
*/

include/tvm/tir/op.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,20 @@ TVM_DLL PrimExpr ret(PrimExpr value, Span span = Span());
9999
*/
100100
TVM_DLL PrimExpr thread_return(Span span = Span());
101101

102+
/*!
103+
* \brief Continue current loop.
104+
* \param span The location of this operation in the source.
105+
* \return The continue loop expression.
106+
*/
107+
TVM_DLL PrimExpr continue_loop(Span span = Span());
108+
109+
/*!
110+
* \brief Break current loop.
111+
* \param span The location of this operation in the source.
112+
* \return The break loop expression.
113+
*/
114+
TVM_DLL PrimExpr break_loop(Span span = Span());
115+
102116
/*!
103117
* Query the maximum possible value of dtype.
104118
* \param dtype The data type.

include/tvm/tir/stmt.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1310,6 +1310,9 @@ constexpr const char* explicit_read_region = "explicit_read_region";
13101310
*/
13111311
constexpr const char* explicit_write_region = "explicit_write_region";
13121312

1313+
/*! \brief ,ark a ForNode represent an irregular loop of non-structural control flow edges. */
1314+
constexpr const char* irregular_loop_mark = "irregular_loop_mark";
1315+
13131316
/*!
13141317
* \brief Check if attr_key is a pragma key extension
13151318
* \param attr_key The attr key to be compared

python/tvm/script/ir_builder/tir/ir.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1917,6 +1917,8 @@ def wrapped(*args, **kwargs):
19171917
q_multiply_shift = _op_wrapper(_tir_op.q_multiply_shift)
19181918
q_multiply_shift_per_axis = _op_wrapper(_tir_op.q_multiply_shift_per_axis)
19191919
ret = _op_wrapper(_tir_op.ret)
1920+
continue_loop = _op_wrapper(_tir_op.continue_loop)
1921+
break_loop = _op_wrapper(_tir_op.break_loop)
19201922
round = _op_wrapper(_tir_op.round) # pylint: disable=redefined-builtin
19211923
rsqrt = _op_wrapper(_tir_op.rsqrt)
19221924
shift_left = _op_wrapper(_tir_op.shift_left)
@@ -2195,6 +2197,8 @@ def wrapped(*args, **kwargs):
21952197
"q_multiply_shift",
21962198
"q_multiply_shift_per_axis",
21972199
"ret",
2200+
"continue_loop",
2201+
"break_loop",
21982202
"reinterpret",
21992203
"round",
22002204
"rsqrt",

python/tvm/script/parser/core/parser.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -872,6 +872,36 @@ def visit_Return(self, node: doc.Return) -> Any: # pylint: disable=invalid-name
872872
"""
873873
return _dispatch(self, "Return")(self, node)
874874

875+
def visit_Continue(self, node: doc.Continue) -> Any: # pylint: disable=invalid-name
876+
"""The general continue visiting method.
877+
878+
Parameters
879+
----------
880+
node : doc.Continue
881+
The doc AST continue node.
882+
883+
Returns
884+
-------
885+
res : Any
886+
The visiting result.
887+
"""
888+
return _dispatch(self, "Continue")(self, node)
889+
890+
def visit_Break(self, node: doc.Break) -> Any: # pylint: disable=invalid-name
891+
"""The general break visiting method.
892+
893+
Parameters
894+
----------
895+
node : doc.Break
896+
The doc AST break node.
897+
898+
Returns
899+
-------
900+
res : Any
901+
The visiting result.
902+
"""
903+
return _dispatch(self, "Break")(self, node)
904+
875905
def visit_Nonlocal(self, node: doc.Nonlocal) -> Any: # pylint: disable=invalid-name
876906
"""The general nonlocal visiting method.
877907

python/tvm/script/parser/tir/parser.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,8 @@ def visit_with(self: Parser, node: doc.With) -> None:
353353
frame = self.eval_expr(item.context_expr)
354354
if not isinstance(frame, Frame):
355355
self.report_error(
356-
item.context_expr, "Invalid context expression in the with-statement."
356+
item.context_expr,
357+
"Invalid context expression in the with-statement.",
357358
)
358359
rhs = stack.enter_context(frame)
359360
if item.optional_vars is not None:
@@ -498,7 +499,8 @@ def visit_if(self: Parser, node: doc.If) -> None:
498499
self.visit_body(node.orelse)
499500
else:
500501
self.report_error(
501-
node.test, f"If condition must be a boolean expression, but got {predicate}"
502+
node.test,
503+
f"If condition must be a boolean expression, but got {predicate}",
502504
)
503505

504506

@@ -539,6 +541,36 @@ def visit_return(self: Parser, node: doc.Return) -> None:
539541
T.evaluate(tvm.tir.ret(value))
540542

541543

544+
@dispatch.register(token="tir", type_name="Continue")
545+
def visit_continue(self: Parser, node: doc.Continue) -> None: # pylint:disable=unused-argument
546+
"""The continue visiting method for tir.
547+
548+
Parameters
549+
----------
550+
self : Parser
551+
The visiting parser.
552+
553+
node : doc.Continue
554+
The doc AST continue node.
555+
"""
556+
T.evaluate(tvm.tir.continue_loop())
557+
558+
559+
@dispatch.register(token="tir", type_name="Break")
560+
def visit_break(self: Parser, node: doc.Break) -> None: # pylint:disable=unused-argument
561+
"""The continue visiting method for tir.
562+
563+
Parameters
564+
----------
565+
self : Parser
566+
The visiting parser.
567+
568+
node : doc.Break
569+
The doc AST break node.
570+
"""
571+
T.evaluate(tvm.tir.break_loop())
572+
573+
542574
@dispatch.register(token="tir", type_name="tvm_declare_function")
543575
def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> GlobalVar:
544576
"""The function declaration step for tir

python/tvm/tir/__init__.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,13 @@
5050
from .op import tvm_stack_alloca, tvm_stack_make_shape, tvm_stack_make_array
5151
from .op import tvm_tuple, handle_add_byte_offset, tvm_struct_get, tvm_struct_set
5252
from .op import address_of, lookup_param, assume, undef
53-
from .op import tvm_thread_allreduce, type_annotation, tvm_access_ptr, tvm_throw_last_error
53+
from .op import continue_loop, break_loop
54+
from .op import (
55+
tvm_thread_allreduce,
56+
type_annotation,
57+
tvm_access_ptr,
58+
tvm_throw_last_error,
59+
)
5460
from .op import (
5561
tvm_load_matrix_sync,
5662
tvm_store_matrix_sync,
@@ -86,7 +92,18 @@
8692
from .op import tan, tanh, atan, atan2, atanh
8793
from .op import bitwise_and, bitwise_not, bitwise_or, bitwise_xor
8894
from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil, hypot
89-
from .op import trunc, abs, round, nextafter, nearbyint, power, pow, popcount, fmod, if_then_else
95+
from .op import (
96+
trunc,
97+
abs,
98+
round,
99+
nextafter,
100+
nearbyint,
101+
power,
102+
pow,
103+
popcount,
104+
fmod,
105+
if_then_else,
106+
)
90107
from .op import likely, isnan, isnullptr, isfinite, isinf, copysign
91108
from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod, ceildiv, logaddexp
92109
from .op import comm_reducer, min, max, sum

python/tvm/tir/op.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1884,8 +1884,7 @@ def ret(val, span=None):
18841884

18851885

18861886
def thread_return(span=None):
1887-
"""Return from a GPU thread.
1888-
1887+
"""Return from a GPU thread
18891888
Parameters
18901889
----------
18911890
span : Optional[Span]
@@ -1900,6 +1899,40 @@ def thread_return(span=None):
19001899
return _ffi_api.thread_return(span)
19011900

19021901

1902+
def continue_loop(span=None):
1903+
"""Create a tir intrinsic call to represent continue expression
1904+
1905+
Parameters
1906+
----------
1907+
span : Optional[Span]
1908+
The location of this operator in the source code.
1909+
1910+
Returns
1911+
-------
1912+
ret : PrimExpr
1913+
The continue expression
1914+
"""
1915+
1916+
return _ffi_api.continue_loop(span)
1917+
1918+
1919+
def break_loop(span=None):
1920+
"""Create a tir intrinsic call to represent break expression
1921+
1922+
Parameters
1923+
----------
1924+
span : Optional[Span]
1925+
The location of this operator in the source code.
1926+
1927+
Returns
1928+
-------
1929+
ret : PrimExpr
1930+
The break expression
1931+
"""
1932+
1933+
return _ffi_api.break_loop(span)
1934+
1935+
19031936
def any(*args, span=None):
19041937
"""Create a new experssion of the union of all conditions in the arguments
19051938

python/tvm/tir/pipeline.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
4343
tir.transform.LowerMatchBuffer(),
4444
tir.transform.Simplify(),
4545
tir.transform.InjectPermutedLayout(),
46+
tir.transform.AnnotateIrregularLoop(),
4647
tir.transform.InjectSoftwarePipeline(),
4748
tir.transform.TransformMmaBufferLayout(),
4849
tir.transform.LowerOpaqueBlock(),

python/tvm/tir/transform/transform.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,19 @@ def AnnotateDeviceRegions():
430430
return _ffi_api.AnnotateDeviceRegions() # type: ignore
431431

432432

433+
def AnnotateIrregularLoop():
434+
"""Annotate irregular loop mark. Loop transformations like
435+
peeling, partition, unroll, etc is not allowed on irregular
436+
loop with internal loop continuation and breaks.
437+
438+
Returns
439+
-------
440+
fpass : tvm.transform.Pass
441+
The result pass
442+
"""
443+
return _ffi_api.AnnotateIrregularLoop() # type: ignore
444+
445+
433446
def SplitHostDevice():
434447
"""Split the function into a host function and device functions.
435448

0 commit comments

Comments
 (0)