Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions testing/python/language/test_tilelang_language_frontend_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import torch
import tilelang.testing
import tvm
from tvm.script.ir_builder.base import IRBuilderFrame
from tvm.tir.expr import IntImm, Var


def test_argument():
Expand Down Expand Up @@ -273,6 +275,43 @@ def foo() -> T.Tensor((128,), T.float32):
assert isinstance(foo, T.PrimFunc)


def test_serial_for_with_step():

@tilelang.jit(out_idx=-1)
@T.prim_func
def test_stepped_serial(A: T.Tensor((10,), T.int32)):
with T.Kernel(1) as _:
for i in range(0, 10, 2):
T.device_assert(0 <= i < 10 and i % 2 == 0, "i out of range")
A[i] = 1.0
for i in range(1, 10, 2):
T.device_assert(1 <= i < 10 and i % 2 == 1, "i out of range")
A[i] = 2.0

ker = test_stepped_serial()
res = ker()
ref = torch.tensor([1, 2, 1, 2, 1, 2, 1, 2, 1, 2], dtype=torch.int32, device='cuda')
assert torch.all(res == ref), f"Expected {ref}, but got {res}"

@tilelang.jit(out_idx=-1)
@T.prim_func
def test_serial_step_neg(A: T.Tensor((10,), T.int32)):
with T.Kernel(1) as _:
for i in range(10, 0, -1):
T.device_assert(0 < i <= 10, "i out of range")
A[10 - i] = i

ker = test_serial_step_neg()
res = ker()
ref = torch.tensor([10, 9, 8, 7, 6, 5, 4, 3, 2, 1], dtype=torch.int32, device='cuda')
assert torch.all(res == ref), f"Expected {ref}, but got {res}"

assert isinstance(T.serial(1, 10, 1), IRBuilderFrame)
assert isinstance(T.serial(1, 10, IntImm('int32', 1)), IRBuilderFrame)
assert not isinstance(T.serial(1, 10, Var('tmp', 'int32')), IRBuilderFrame)
assert not isinstance(T.serial(10, -1, -1), IRBuilderFrame)


def test_swap_logic():

@tilelang.jit
Expand Down
4 changes: 1 addition & 3 deletions tilelang/language/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@
SharedBuffer, # noqa: F401
LocalBuffer, # noqa: F401
)
from .parallel import Parallel # noqa: F401
from .pipeline import Pipelined # noqa: F401
from .persistent import Persistent # noqa: F401
from .loop import serial, Parallel, Persistent, Pipelined # noqa: F401
from .frame import has_let_value, get_let_value # noqa: F401
from .math_intrinsics import * # noqa: F401
from .kernel import (
Expand Down
111 changes: 111 additions & 0 deletions tilelang/language/loop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""The language interface for tl programs."""
from __future__ import annotations

from typing import Any
from tvm import tir
from tvm.tir import IntImm
import tvm.script.ir_builder.tir as tb_tir
from .v2.builder import SerialForWithStep
from tilelang import _ffi_api


def Parallel(*extents: tir.PrimExpr, coalesced_width: int | None = None):
"""Tools to construct nested parallel for loop.
This can be used to create element-wise tensor expression.
Parameters
----------
extents : PrimExpr
The extents of the iteration.
coalesced_width : Optional[int]
The coalesced width of the parallel loop.
Returns
-------
res : frame.ForFrame
The ForFrame.
"""
annotations: dict[str, Any] = {}
if coalesced_width is not None:
annotations.update({"coalesced_width": coalesced_width})
return _ffi_api.Parallel(extents, annotations) # type: ignore[attr-defined] # pylint: disable=no-member


def Persistent(
domain: list[tir.PrimExpr],
wave_size: tir.PrimExpr,
index: tir.PrimExpr,
group_size: tir.PrimExpr | None = 8,
):
"""Tools to construct persistent for loop.
Parameters
----------
domain : List[tir.PrimExpr]
The list of dominators.
wave_size : int
The wave size.
index : int
The tile index in one wave.
group_size : tir.PrimExpr
The group size.
"""
return _ffi_api.Persistent(domain, wave_size, index, group_size)


def Pipelined(
start: tir.PrimExpr,
stop: tir.PrimExpr = None,
num_stages: int = 0,
order: list[int] | None = None,
stage: list[int] | None = None,
sync: list[list[int]] | None = None,
group: list[list[int]] | None = None,
):
"""Tools to construct pipelined for loop.
Parameters
----------
start : PrimExpr
The minimum value of iteration.
stop : PrimExpr
The maximum value of iteration.
num_stages : int
The max number of buffer used between pipeline producers and consumers.
if num_stages is 0, pipeline will not be enabled.
Returns
-------
res : frame.ForFrame
The ForFrame.
"""
if stop is None:
stop = start
start = IntImm(start.dtype, 0) if hasattr(start, "dtype") else 0
if order is None:
order = []
if stage is None:
stage = []
if sync is None:
sync = []
if group is None:
group = []
# type: ignore[attr-defined] # pylint: disable=no-member
return _ffi_api.Pipelined(start, stop, num_stages, order, stage, sync, group)


def serial(start: tir.PrimExpr,
stop: tir.PrimExpr | None = None,
step: tir.PrimExpr | None = None,
*,
annotations: dict[str, Any] | None = None):
step_is_one = False
step_is_one |= isinstance(step, int) and step == 1
step_is_one |= isinstance(step, IntImm) and step.value == 1
if step is None or step_is_one:
return tb_tir.serial(start, stop, annotations=annotations)
else:
if stop is None:
stop = start
start = IntImm(start.dtype, 0) if hasattr(start, "dtype") else 0
return SerialForWithStep(start, stop, step, annotations=annotations)
Comment on lines +97 to +111
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Missing zero-step validation.

The function correctly routes step values to either tb_tir.serial (for step=None or step=1) or SerialForWithStep, and properly normalizes start/stop. However, zero step is not validated before creating SerialForWithStep, which will cause division by zero in builder.py::ctx_for line 255.

Apply this diff to validate step:

 def serial(start: tir.PrimExpr,
            stop: tir.PrimExpr | None = None,
            step: tir.PrimExpr | None = None,
            *,
            annotations: dict[str, Any] | None = None):
+    # Validate non-zero step for constant values
+    if isinstance(step, int) and step == 0:
+        raise ValueError("Serial loop step must not be zero")
+    if isinstance(step, IntImm) and step.value == 0:
+        raise ValueError("Serial loop step must not be zero")
+    
     step_is_one = False
     step_is_one |= isinstance(step, int) and step == 1
     step_is_one |= isinstance(step, IntImm) and step.value == 1

Minor: Consider using logical or instead of bitwise |=.

While lines 103-104 work correctly, the bitwise OR pattern is unconventional for boolean accumulation. Consider:

-    step_is_one = False
-    step_is_one |= isinstance(step, int) and step == 1
-    step_is_one |= isinstance(step, IntImm) and step.value == 1
+    step_is_one = (isinstance(step, int) and step == 1) or \
+                  (isinstance(step, IntImm) and step.value == 1)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def serial(start: tir.PrimExpr,
stop: tir.PrimExpr | None = None,
step: tir.PrimExpr | None = None,
*,
annotations: dict[str, Any] | None = None):
step_is_one = False
step_is_one |= isinstance(step, int) and step == 1
step_is_one |= isinstance(step, IntImm) and step.value == 1
if step is None or step_is_one:
return tb_tir.serial(start, stop, annotations=annotations)
else:
if stop is None:
stop = start
start = IntImm(start.dtype, 0) if hasattr(start, "dtype") else 0
return SerialForWithStep(start, stop, step, annotations=annotations)
def serial(start: tir.PrimExpr,
stop: tir.PrimExpr | None = None,
step: tir.PrimExpr | None = None,
*,
annotations: dict[str, Any] | None = None):
# Validate non-zero step for constant values
if isinstance(step, int) and step == 0:
raise ValueError("Serial loop step must not be zero")
if isinstance(step, IntImm) and step.value == 0:
raise ValueError("Serial loop step must not be zero")
step_is_one = (isinstance(step, int) and step == 1) or \
(isinstance(step, IntImm) and step.value == 1)
if step is None or step_is_one:
return tb_tir.serial(start, stop, annotations=annotations)
else:
if stop is None:
stop = start
start = IntImm(start.dtype, 0) if hasattr(start, "dtype") else 0
return SerialForWithStep(start, stop, step, annotations=annotations)
🤖 Prompt for AI Agents
In tilelang/language/loop.py around lines 97 to 111, the function lacks
validation for a zero step which will cause a division-by-zero downstream; add a
guard before creating SerialForWithStep that checks if step represents zero
(e.g., isinstance(step, int) and step == 0 or isinstance(step, IntImm) and
step.value == 0) and raise a ValueError with a clear message if so; also replace
the bitwise accumulation using '|=' on step_is_one with conventional boolean
logic using 'or' (e.g., step_is_one = (isinstance(step, int) and step == 1) or
(isinstance(step, IntImm) and step.value == 1)) so the intent is clearer.

29 changes: 0 additions & 29 deletions tilelang/language/parallel.py

This file was deleted.

27 changes: 0 additions & 27 deletions tilelang/language/persistent.py

This file was deleted.

46 changes: 0 additions & 46 deletions tilelang/language/pipeline.py

This file was deleted.

38 changes: 31 additions & 7 deletions tilelang/language/v2/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,14 @@ class BreakFrame(Frame):
...


@dataclass
class SerialForWithStep:
start: PrimExpr
stop: PrimExpr
step: PrimExpr
annotations: dict[str, Any] | None = None


# Python 3.9 compatibility: avoid PEP 604 unions at runtime
# Use tuple for isinstance checks and typing.Union for annotations/aliases
ContinueOrBreak = (ContinueFrame, BreakFrame)
Expand Down Expand Up @@ -243,12 +251,27 @@ def eval(self, val: Any):
def ctx_for(self, it):
self.check_continue_break()
it = unwrap_expr(it)
if not isinstance(it, tir.frame.ForFrame):
raise TypeError(
f"Invalid for loop, got {it}({type(it)}), expect one of the following: "
"range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding")
with self.with_frame(it) as v:
yield v
if isinstance(it, SerialForWithStep):
real_stop = tir.ceildiv(it.stop - it.start, it.step)
if isinstance(it.step, (int, IntImm)):
value = it.step if isinstance(it.step, int) else it.step.value
if value < 0:
real_stop = tir.ceildiv(it.start - it.stop, -it.step)
else:
logger.warning(
f'Using a non-constant step `{it.step}` in stepped serial may lead to undefined behavior in tilelang'
)
real_frame = tir.serial(real_stop, annotations=it.annotations)
with self.with_frame(real_frame) as v:
IRBuilder.name('_tmp', v)
yield it.start + v * it.step
else:
if not isinstance(it, tir.frame.ForFrame):
raise TypeError(
f"Invalid for loop, got {it}({type(it)}), expect one of the following: "
"range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding")
with self.with_frame(it) as v:
yield v

def ctx_continue(self):
self.check_continue_break()
Expand Down Expand Up @@ -459,8 +482,9 @@ def arg(self, name, value):
f"Unsupported argument type: {value}({type(value)}) for argument `{name}`.")

def override(self, name: str):
import tilelang.language as T
if name == 'range':
return tir.serial
return T.serial
raise ValueError(f'Unknown override: {name}')


Expand Down
Loading