diff --git a/testing/python/language/test_tilelang_language_frontend_v2.py b/testing/python/language/test_tilelang_language_frontend_v2.py index 915574c3e..fb3f1e15a 100644 --- a/testing/python/language/test_tilelang_language_frontend_v2.py +++ b/testing/python/language/test_tilelang_language_frontend_v2.py @@ -342,5 +342,23 @@ def swap_idx(A: T.Tensor[(2,), T.float32]): torch.testing.assert_close(data, ref) +def test_while_loop(): + + @tilelang.jit(out_idx=-1) + @T.prim_func + def test_while_loop(A: T.Tensor((1,), T.int32)): + with T.Kernel(1) as _: + i = T.alloc_var(T.int32, 0) + sum = T.alloc_var(T.int32) + while i < 10: + sum += i + i += 1 + A[0] = sum + + ker = test_while_loop() + A = ker() + assert A[0].item() == sum(range(10)), f"Expected {sum(range(10))}, but got {A[0].item()}" + + if __name__ == '__main__': tilelang.testing.main() diff --git a/tilelang/language/tir/ir.pyi b/tilelang/language/tir/ir.pyi new file mode 100644 index 000000000..fe25b58f8 --- /dev/null +++ b/tilelang/language/tir/ir.pyi @@ -0,0 +1,106 @@ +from typing import TypeVar, Literal +from tvm.tir.expr import Span, PrimExpr, BufferLoad, Var, IntImm + +_T = TypeVar('_T') + +def abs(x: _T, span: Span | None=None) -> _T: ... +def acos(x: _T) -> _T: ... +def acosh(x: _T) -> _T: ... +def address_of(buffer_load: BufferLoad, span: Span | None=None) -> PrimExpr: ... +def asin(x: _T) -> _T: ... +def asinh(x: _T) -> _T: ... +def atan(x: _T) -> _T: ... +def atan2(x1: _T, x2: _T) -> _T: ... +def atanh(x: _T) -> _T: ... +def bitwise_and(x: _T, y: _T, span: Span | None=None) -> _T: ... +def bitwise_not(x: _T, span: Span | None=None) -> _T: ... +def bitwise_or(x: _T, y: _T, span: Span | None=None) -> _T: ... +def bitwise_xor(x: _T, y: _T, span: Span | None=None) -> _T: ... +def ceil(x: _T, span: Span | None=None) -> _T: ... +def clz(x: _T) -> _T: ... +def copysign(x1: _T, x2: _T) -> _T: ... +def cos(x: _T) -> _T: ... +def cosh(x: _T) -> _T: ... +def erf(x: _T) -> _T: ... +def exp(x: _T) -> _T: ... +def exp2(x: _T) -> _T: ... +def exp10(x: _T) -> _T: ... +def floor(x: _T, span: Span | None=None) -> _T: ... +def ceildiv(lhs: _T, rhs: _T, span: Span | None=None) -> _T: ... +def floordiv(a: _T, b: _T, span: Span | None=None) -> _T: ... +def floormod(a: _T, b: _T, span: Span | None=None) -> _T: ... +def fmod(x: _T, y: _T) -> _T: ... +def hypot(x1: _T, x2: _T) -> _T: ... +def if_then_else(cond: PrimExpr, t: _T, f: _T, span: Span | None=None) -> _T: ... +def infinity(dtype: _T, span: Span | None=None) -> _T: ... +def isfinite(x: _T, span: Span | None=None) -> _T: ... +def isinf(x: _T, span: Span | None=None) -> _T: ... +def isnan(x: _T, span: Span | None=None) -> _T: ... +def isnullptr(x: _T, span: Span | None=None) -> _T: ... +def ldexp(x1: _T, x2: _T) -> _T: ... +def likely(cond: _T, span: Span | None=None) -> _T: ... +def log(x: _T) -> _T: ... +def log1p(x: _T) -> _T: ... +def log2(x: _T) -> _T: ... +def log10(x: _T) -> _T: ... +def lookup_param(param_name: str, span: Span | None=None) -> PrimExpr: ... +def max_value(dtype: str, span: Span | None=None) -> PrimExpr: ... +def min_value(dtype: str, span: Span | None=None) -> PrimExpr: ... +def nearbyint(x: _T, span: Span | None=None) -> _T: ... +def nextafter(x1: _T, x2: _T) -> _T: ... +def popcount(x: _T) -> _T: ... +def pow(x: _T, y: _T, span: Span | None=None) -> _T: ... +def q_multiply_shift(x: _T, y: _T, q: _T, s: _T) -> _T: ... +def q_multiply_shift_per_axis(x: _T, y: _T, ls: _T, rs: _T, q: IntImm, is_lshift_required: IntImm, is_rshift_required: IntImm) -> PrimExpr: ... +def ret(val: _T) -> _T: ... +def round(x: _T, span: Span | None=None) -> _T: ... +def rsqrt(x: _T) -> _T: ... +def shift_left(x: _T, y: _T, span=None) -> _T: ... +def shift_right(x: _T, y: _T, span=None) -> _T: ... +def sigmoid(x: _T) -> _T: ... +def sin(x: _T) -> _T: ... +def sinh(x: _T) -> _T: ... +def sqrt(x: _T) -> _T: ... +def tan(x: _T) -> _T: ... +def tanh(x: _T) -> _T: ... +def trunc(x: _T, span: Span | None=None) -> _T: ... +def truncdiv(a: _T, b: _T, span: Span | None=None) -> _T: ... +def truncmod(a: _T, b: _T, span: Span | None=None) -> _T: ... +def tvm_access_ptr(ptype: PrimExpr, data, offset: int, extent: int, rw_mask: int) -> PrimExpr: ... +def tvm_throw_last_error() -> _T: ... +def tvm_stack_alloca(dtype_str: str, num: int) -> PrimExpr: ... +def tvm_stack_make_shape(*args) -> _T: ... +def tvm_stack_make_array(data: PrimExpr, shape: PrimExpr, strides: PrimExpr, ndim: PrimExpr, arr_dtype: PrimExpr, elem_offset) -> PrimExpr: ... +def tvm_check_return(expected: int, return_unexpected: int, nested_call: PrimExpr) -> PrimExpr: ... +def call_packed(*args, span=None) -> _T: ... +def call_cpacked(*args, span=None) -> _T: ... +def call_packed_lowered(*args, span=None) -> _T: ... +def call_cpacked_lowered(*args, span=None) -> _T: ... +def tvm_tuple(*value) -> _T: ... +def tvm_struct_set(arr, index: int, field: int, value: PrimExpr) -> PrimExpr: ... +def tvm_thread_invariant(cond: _T) -> _T: ... +def tvm_thread_allreduce(*freduce_args) -> _T: ... +def tvm_load_matrix_sync(fragment: Var, m: IntImm, n: IntImm, k: IntImm, index: PrimExpr, buffer_ptr: PrimExpr, stride: PrimExpr, layout: Literal['row_major', 'column_major']) -> PrimExpr: ... +def tvm_mma_sync(fragment_d: Var, index_d: PrimExpr, fragment_a: Var, index_a: PrimExpr, fragment_b: Var, index_b: PrimExpr, fragment_c: Var, index_c: PrimExpr) -> PrimExpr: ... +def tvm_bmma_sync(fragment_d: Var, index_d: PrimExpr, fragment_a: Var, index_a: PrimExpr, fragment_b: Var, index_b: PrimExpr, fragment_c: Var, index_c: PrimExpr) -> PrimExpr: ... +def tvm_fill_fragment(fragment: Var, m: IntImm, n: IntImm, k: IntImm, index: PrimExpr, value: PrimExpr) -> PrimExpr: ... +def tvm_store_matrix_sync(fragment: Var, m: IntImm, n: IntImm, k: IntImm, index: PrimExpr, buffer_ptr: PrimExpr, stride: PrimExpr, layout: Literal['row_major', 'column_major']) -> PrimExpr: ... +def ptx_wait_group(num: int) -> PrimExpr: ... +def ptx_commit_group() -> _T: ... +def ptx_cp_async_barrier(barrier_id: int) -> PrimExpr: ... +def ptx_init_barrier_thread_count(barrier_id: int, thread_count: int) -> PrimExpr: ... +def ptx_arrive_barrier(barrier_id: int) -> PrimExpr: ... +def ptx_arrive_barrier_expect_tx(barrier_id: int, byte_count: int) -> PrimExpr: ... +def ptx_wait_barrier(barrier_id: int) -> PrimExpr: ... +def create_barriers(barrier_count: int) -> PrimExpr: ... +def assume(cond: _T=None) -> _T: ... +def undef() -> _T: ... +def TVMBackendAllocWorkspace(device_type: int, device_id: int, nbytes: int, dtype_code_hint: int, dtype_bits_hint: int) -> PrimExpr: ... +def TVMBackendFreeWorkspace(device_type: int, device_id: int, ptr: Var) -> PrimExpr: ... +def start_profile_intrinsic(id: int) -> PrimExpr: ... +def end_profile_intrinsic(id: int) -> PrimExpr: ... +def anylist_getitem(list_handle, index) -> PrimExpr: ... +def anylist_resetitem(list_handle, index) -> PrimExpr: ... +def anylist_setitem_call_packed(list_handle, index, func_name, *args) -> PrimExpr: ... +def anylist_setitem_call_cpacked(list_handle, index, func_name, *args) -> PrimExpr: ... +def vscale() -> _T: ... diff --git a/tilelang/language/v2/ast.py b/tilelang/language/v2/ast.py index a8390cfc3..cf879ee59 100644 --- a/tilelang/language/v2/ast.py +++ b/tilelang/language/v2/ast.py @@ -469,6 +469,7 @@ def visit_AnnAssign(self, node: ast.AnnAssign): return self._emit_assign_target(node.target, rval, annot=node.annotation) def visit_While(self, node): + node = self.generic_visit(node) return quote1( "for _ in __tb.ctx_while(lambda: cond):\n pass", cond=node.test, diff --git a/tilelang/language/v2/builder.py b/tilelang/language/v2/builder.py index 780019c3f..90c8a8e99 100644 --- a/tilelang/language/v2/builder.py +++ b/tilelang/language/v2/builder.py @@ -292,7 +292,22 @@ def ctx_break(self): def ctx_while(self, cond): self.check_continue_break() - raise RuntimeError("while loops are not supported in TileLang builder") + cond_v = cond() + cond_v_unwrap = unwrap_cond(cond_v) + if not isinstance(cond_v_unwrap, PrimExpr): + if cond_v_unwrap: + raise RuntimeError( + f'Infinite while loop detected in TileLang\n' + f'Condition: {cond_v}({type(cond_v)}) => {cond_v_unwrap}({type(cond_v_unwrap)})\n' + ) + else: + logger.warning( + 'While loop with constant false condition detected in Tilelang, the loop body will never be executed.\n', + f'Condition: {cond_v}({type(cond_v)}) => {cond_v_unwrap}({type(cond_v_unwrap)})\n', + stack_info=True, + stacklevel=2) + with self.with_frame(tir.While(cond_v_unwrap)): + yield None def bind(self, name, value, annot=BaseBuilder.empty): self.check_continue_break()