Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions testing/python/language/test_tilelang_language_frontend_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import torch
import tilelang.testing
import tvm
from tvm.script.ir_builder.base import IRBuilderFrame
from tvm.tir.expr import IntImm, Var


def test_argument():
Expand Down Expand Up @@ -273,6 +275,43 @@ def foo() -> T.Tensor((128,), T.float32):
assert isinstance(foo, T.PrimFunc)


def test_serial_for_with_step():
    """Check stepped serial loops written with ``range`` and ``T.serial``.

    Two kernels are compiled and run, one with a positive step and one with a
    negative step, and their outputs are compared against reference tensors.
    The trailing assertions check which kind of object ``T.serial`` returns
    for different step arguments.
    """

    # NOTE(review): out_idx=-1 presumably marks the last argument as the
    # kernel output (ker() is called without arguments and returns a tensor)
    # -- confirm against tilelang.jit.
    @tilelang.jit(out_idx=-1)
    @T.prim_func
    def test_stepped_serial(A: T.Tensor((10,), T.int32)):
        with T.Kernel(1) as _:
            # Even indices 0, 2, ..., 8 are written with 1.
            for i in range(0, 10, 2):
                T.device_assert(0 <= i < 10 and i % 2 == 0, "i out of range")
                A[i] = 1.0
            # Odd indices 1, 3, ..., 9 are written with 2.
            for i in range(1, 10, 2):
                T.device_assert(1 <= i < 10 and i % 2 == 1, "i out of range")
                A[i] = 2.0

    ker = test_stepped_serial()
    res = ker()
    ref = torch.tensor([1, 2, 1, 2, 1, 2, 1, 2, 1, 2], dtype=torch.int32, device='cuda')
    assert torch.all(res == ref), f"Expected {ref}, but got {res}"

    @tilelang.jit(out_idx=-1)
    @T.prim_func
    def test_serial_step_neg(A: T.Tensor((10,), T.int32)):
        with T.Kernel(1) as _:
            # Counts down from 10 to 1; slot 10 - i receives the value i.
            for i in range(10, 0, -1):
                T.device_assert(0 < i <= 10, "i out of range")
                A[10 - i] = i

    ker = test_serial_step_neg()
    res = ker()
    ref = torch.tensor([10, 9, 8, 7, 6, 5, 4, 3, 2, 1], dtype=torch.int32, device='cuda')
    assert torch.all(res == ref), f"Expected {ref}, but got {res}"

    # A constant step of 1 (int or IntImm) is expected to yield an
    # IRBuilderFrame; a symbolic step or a step of -1 is expected not to.
    assert isinstance(T.serial(1, 10, 1), IRBuilderFrame)
    assert isinstance(T.serial(1, 10, IntImm('int32', 1)), IRBuilderFrame)
    assert not isinstance(T.serial(1, 10, Var('tmp', 'int32')), IRBuilderFrame)
    assert not isinstance(T.serial(10, -1, -1), IRBuilderFrame)


def test_swap_logic():

@tilelang.jit
Expand Down
4 changes: 1 addition & 3 deletions tilelang/language/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@
SharedBuffer, # noqa: F401
LocalBuffer, # noqa: F401
)
from .parallel import Parallel # noqa: F401
from .pipeline import Pipelined # noqa: F401
from .persistent import Persistent # noqa: F401
from .loop import serial, Parallel, Persistent, Pipelined # noqa: F401
from .frame import has_let_value, get_let_value # noqa: F401
from .math_intrinsics import * # noqa: F401
from .kernel import (
Expand Down
111 changes: 111 additions & 0 deletions tilelang/language/loop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""The language interface for tl programs."""
from __future__ import annotations

from typing import Any
from tvm import tir
from tvm.tir import IntImm
import tvm.script.ir_builder.tir as tb_tir
from .v2.builder import SerialForWithStep
from tilelang import _ffi_api


def Parallel(*extents: tir.PrimExpr, coalesced_width: int | None = None):
    """Build a nested parallel for loop over the given extents.

    Useful for writing element-wise tensor expressions.

    Parameters
    ----------
    extents : PrimExpr
        The extent of each loop dimension, passed positionally.
    coalesced_width : Optional[int]
        The coalesced width of the parallel loop. When given, it is stored
        under the ``"coalesced_width"`` annotation key; when ``None`` no
        annotation is attached.

    Returns
    -------
    res : frame.ForFrame
        The ForFrame.
    """
    annotations: dict[str, Any] = ({} if coalesced_width is None else {
        "coalesced_width": coalesced_width
    })
    return _ffi_api.Parallel(extents, annotations)  # type: ignore[attr-defined] # pylint: disable=no-member


def Persistent(
    domain: list[tir.PrimExpr],
    wave_size: tir.PrimExpr,
    index: tir.PrimExpr,
    group_size: tir.PrimExpr | None = 8,
):
    """Build a persistent for loop.

    All arguments are forwarded unchanged to ``_ffi_api.Persistent``.

    Parameters
    ----------
    domain : List[tir.PrimExpr]
        The iteration domain, given as a list of expressions.
    wave_size : tir.PrimExpr
        The wave size.
    index : tir.PrimExpr
        The tile index in one wave.
    group_size : Optional[tir.PrimExpr]
        The group size. Defaults to 8.

    Returns
    -------
    res
        Whatever ``_ffi_api.Persistent`` returns (presumably a loop frame --
        confirm against the FFI implementation).
    """
    persistent_frame = _ffi_api.Persistent(domain, wave_size, index, group_size)
    return persistent_frame


def Pipelined(
    start: tir.PrimExpr,
    stop: tir.PrimExpr = None,
    num_stages: int = 0,
    order: list[int] | None = None,
    stage: list[int] | None = None,
    sync: list[list[int]] | None = None,
    group: list[list[int]] | None = None,
):
    """Build a pipelined for loop.

    Parameters
    ----------
    start : PrimExpr
        The minimum value of iteration. When ``stop`` is omitted this value
        is used as the stop instead and iteration begins at zero.
    stop : PrimExpr
        The maximum value of iteration.
    num_stages : int
        The max number of buffers used between pipeline producers and
        consumers. If ``num_stages`` is 0, pipelining is not enabled.
    order, stage : Optional[List[int]]
        Forwarded to ``_ffi_api.Pipelined``; ``None`` is replaced by an empty
        list. NOTE(review): their exact meaning is not visible here.
    sync, group : Optional[List[List[int]]]
        Forwarded to ``_ffi_api.Pipelined``; ``None`` is replaced by an empty
        list. NOTE(review): their exact meaning is not visible here.

    Returns
    -------
    res : frame.ForFrame
        The ForFrame.
    """
    if stop is None:
        # Single-bound form: the given bound becomes the stop and the loop
        # starts at a zero of the same dtype when the bound carries one.
        zero = IntImm(start.dtype, 0) if hasattr(start, "dtype") else 0
        start, stop = zero, start

    # Replace omitted optional lists with fresh empty lists.
    order = [] if order is None else order
    stage = [] if stage is None else stage
    sync = [] if sync is None else sync
    group = [] if group is None else group

    return _ffi_api.Pipelined(  # type: ignore[attr-defined] # pylint: disable=no-member
        start, stop, num_stages, order, stage, sync, group)


def serial(start: tir.PrimExpr,
           stop: tir.PrimExpr | None = None,
           step: tir.PrimExpr | None = None,
           *,
           annotations: dict[str, Any] | None = None):
    """Build a serial for loop, optionally with an explicit step.

    Parameters
    ----------
    start : PrimExpr
        The first value of the loop variable. When ``stop`` is omitted on
        the stepped path, this value is used as the stop instead and
        iteration begins at zero.
    stop : Optional[PrimExpr]
        The exclusive end of iteration.
    step : Optional[PrimExpr]
        The increment between iterations. ``None`` or a constant 1 (``int``
        or ``IntImm``) selects the plain ``tb_tir.serial`` loop; any other
        value produces a ``SerialForWithStep`` description.
    annotations : Optional[Dict[str, Any]]
        Loop annotations, forwarded unchanged.

    Returns
    -------
    res
        The result of ``tb_tir.serial`` for unit/absent steps, otherwise a
        ``SerialForWithStep`` instance.

    Raises
    ------
    ValueError
        If ``step`` is a constant zero (``int`` or ``IntImm``).
    """
    # Extract the step as a Python int when it is statically known.
    const_step = None
    if isinstance(step, int):
        const_step = step
    elif isinstance(step, IntImm):
        const_step = step.value

    # Reject a zero step here so the error points at the construction site
    # rather than surfacing later when the loop is entered.
    if const_step == 0:
        raise ValueError('Invalid stepped serial: step must be non-zero')

    if step is None or const_step == 1:
        return tb_tir.serial(start, stop, annotations=annotations)

    if stop is None:
        # Single-bound form: the given bound becomes the stop and the loop
        # starts at a zero of the same dtype when the bound carries one.
        stop = start
        start = IntImm(start.dtype, 0) if hasattr(start, "dtype") else 0
    return SerialForWithStep(start, stop, step, annotations=annotations)
Comment on lines +97 to +111
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Missing zero-step validation.

The function correctly routes step values to either tb_tir.serial (for step=None or step=1) or SerialForWithStep, and properly normalizes start/stop. However, zero step is not validated before creating SerialForWithStep, which will cause division by zero in builder.py::ctx_for line 255.

Apply this diff to validate step:

 def serial(start: tir.PrimExpr,
            stop: tir.PrimExpr | None = None,
            step: tir.PrimExpr | None = None,
            *,
            annotations: dict[str, Any] | None = None):
+    # Validate non-zero step for constant values
+    if isinstance(step, int) and step == 0:
+        raise ValueError("Serial loop step must not be zero")
+    if isinstance(step, IntImm) and step.value == 0:
+        raise ValueError("Serial loop step must not be zero")
+    
     step_is_one = False
     step_is_one |= isinstance(step, int) and step == 1
     step_is_one |= isinstance(step, IntImm) and step.value == 1

Minor: Consider using logical or instead of bitwise |=.

While lines 103-104 work correctly, the bitwise OR pattern is unconventional for boolean accumulation. Consider:

-    step_is_one = False
-    step_is_one |= isinstance(step, int) and step == 1
-    step_is_one |= isinstance(step, IntImm) and step.value == 1
+    step_is_one = (isinstance(step, int) and step == 1) or \
+                  (isinstance(step, IntImm) and step.value == 1)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def serial(start: tir.PrimExpr,
stop: tir.PrimExpr | None = None,
step: tir.PrimExpr | None = None,
*,
annotations: dict[str, Any] | None = None):
step_is_one = False
step_is_one |= isinstance(step, int) and step == 1
step_is_one |= isinstance(step, IntImm) and step.value == 1
if step is None or step_is_one:
return tb_tir.serial(start, stop, annotations=annotations)
else:
if stop is None:
stop = start
start = IntImm(start.dtype, 0) if hasattr(start, "dtype") else 0
return SerialForWithStep(start, stop, step, annotations=annotations)
def serial(start: tir.PrimExpr,
stop: tir.PrimExpr | None = None,
step: tir.PrimExpr | None = None,
*,
annotations: dict[str, Any] | None = None):
# Validate non-zero step for constant values
if isinstance(step, int) and step == 0:
raise ValueError("Serial loop step must not be zero")
if isinstance(step, IntImm) and step.value == 0:
raise ValueError("Serial loop step must not be zero")
step_is_one = (isinstance(step, int) and step == 1) or \
(isinstance(step, IntImm) and step.value == 1)
if step is None or step_is_one:
return tb_tir.serial(start, stop, annotations=annotations)
else:
if stop is None:
stop = start
start = IntImm(start.dtype, 0) if hasattr(start, "dtype") else 0
return SerialForWithStep(start, stop, step, annotations=annotations)
🤖 Prompt for AI Agents
In tilelang/language/loop.py around lines 97 to 111, the function lacks
validation for a zero step which will cause a division-by-zero downstream; add a
guard before creating SerialForWithStep that checks if step represents zero
(e.g., isinstance(step, int) and step == 0 or isinstance(step, IntImm) and
step.value == 0) and raise a ValueError with a clear message if so; also replace
the bitwise accumulation using '|=' on step_is_one with conventional boolean
logic using 'or' (e.g., step_is_one = (isinstance(step, int) and step == 1) or
(isinstance(step, IntImm) and step.value == 1)) so the intent is clearer.

29 changes: 0 additions & 29 deletions tilelang/language/parallel.py

This file was deleted.

27 changes: 0 additions & 27 deletions tilelang/language/persistent.py

This file was deleted.

46 changes: 0 additions & 46 deletions tilelang/language/pipeline.py

This file was deleted.

43 changes: 36 additions & 7 deletions tilelang/language/v2/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,14 @@ class BreakFrame(Frame):
...


@dataclass
class SerialForWithStep:
    """Plain record describing a serial loop that has an explicit step.

    It only stores the loop parameters; it does not build any loop itself.
    """
    # First value taken by the loop variable.
    start: PrimExpr
    # End of iteration (presumably exclusive, as with Python's range -- TODO confirm).
    stop: PrimExpr
    # Increment applied between consecutive iterations.
    step: PrimExpr
    # Optional loop annotations.
    annotations: dict[str, Any] | None = None


# Python 3.9 compatibility: avoid PEP 604 unions at runtime
# Use tuple for isinstance checks and typing.Union for annotations/aliases
ContinueOrBreak = (ContinueFrame, BreakFrame)
Expand Down Expand Up @@ -243,12 +251,32 @@ def eval(self, val: Any):
def ctx_for(self, it):
self.check_continue_break()
it = unwrap_expr(it)
if not isinstance(it, tir.frame.ForFrame):
raise TypeError(
f"Invalid for loop, got {it}({type(it)}), expect one of the following: "
"range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding")
with self.with_frame(it) as v:
yield v
if isinstance(it, SerialForWithStep):
# Validate and compute the trip count before constructing the frame
if isinstance(it.step, (int, IntImm)):
step_value = it.step if isinstance(it.step, int) else it.step.value
if step_value == 0:
raise ValueError('Invalid stepped serial: step must be non-zero')
if step_value > 0:
real_stop = tir.ceildiv(it.stop - it.start, step_value)
else:
real_stop = tir.ceildiv(it.start - it.stop, -step_value)
Comment on lines +256 to +263
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Critical: Handle empty iteration range for negative steps.

When step < 0 and start <= stop (e.g., T.serial(0, 10, -3)), line 263 computes ceildiv(0 - 10, 3) = -3, producing a negative real_stop. This causes invalid IR generation when passed to tir.serial at line 269.

Python's range(0, 10, -3) produces an empty sequence; the DSL should yield 0 iterations in this case.

Apply this diff to clamp real_stop to non-negative values:

                if step_value > 0:
-                   real_stop = tir.ceildiv(it.stop - it.start, step_value)
+                   real_stop = tir.max(tir.ceildiv(it.stop - it.start, step_value), 0)
                else:
-                   real_stop = tir.ceildiv(it.start - it.stop, -step_value)
+                   real_stop = tir.max(tir.ceildiv(it.start - it.stop, -step_value), 0)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if isinstance(it.step, (int, IntImm)):
step_value = it.step if isinstance(it.step, int) else it.step.value
if step_value == 0:
raise ValueError('Invalid stepped serial: step must be non-zero')
if step_value > 0:
real_stop = tir.ceildiv(it.stop - it.start, step_value)
else:
real_stop = tir.ceildiv(it.start - it.stop, -step_value)
if isinstance(it.step, (int, IntImm)):
step_value = it.step if isinstance(it.step, int) else it.step.value
if step_value == 0:
raise ValueError('Invalid stepped serial: step must be non-zero')
if step_value > 0:
real_stop = tir.max(tir.ceildiv(it.stop - it.start, step_value), 0)
else:
real_stop = tir.max(tir.ceildiv(it.start - it.stop, -step_value), 0)
🧰 Tools
🪛 Ruff (0.14.3)

259-259: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In tilelang/language/v2/builder.py around lines 256 to 263, the computation of
real_stop can produce a negative count when the step sign does not match the
start/stop ordering (e.g., start <= stop with step < 0), causing invalid IR;
after computing real_stop, clamp it to a non-negative value with the TIR
operator (e.g., real_stop = tir.max(real_stop, 0)) so empty iteration ranges
yield zero iterations before it is passed to tir.serial.

else:
logger.warning(
f'Using a non-constant step `{it.step}` in stepped serial may lead to undefined behavior in tilelang'
)
real_stop = tir.ceildiv(it.stop - it.start, it.step)
real_frame = tir.serial(real_stop, annotations=it.annotations)
Comment on lines +256 to +269

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Guard trip count when step cannot reach stop

The new stepped loop path computes real_stop with tir.ceildiv and passes it directly to tir.serial. When the step sign is incompatible with the start/stop ordering (e.g. range(10, 0, 2) or range(0, 10, -2)), (stop - start) and step have opposite signs, so real_stop becomes negative and we attempt to build a serial loop with a negative extent. TVM’s serial builder rejects negative extents, raising an error for loops that should simply execute zero iterations per Python semantics. Consider clamping the trip count to zero when no iterations are expected before creating the tir.serial frame.

Useful? React with 👍 / 👎.

with self.with_frame(real_frame) as v:
IRBuilder.name('_tmp', v)
yield it.start + v * it.step
else:
if not isinstance(it, tir.frame.ForFrame):
raise TypeError(
f"Invalid for loop, got {it}({type(it)}), expect one of the following: "
"range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding")
with self.with_frame(it) as v:
yield v

def ctx_continue(self):
self.check_continue_break()
Expand Down Expand Up @@ -459,8 +487,9 @@ def arg(self, name, value):
f"Unsupported argument type: {value}({type(value)}) for argument `{name}`.")

def override(self, name: str):
    """Return the replacement object for an overridden builtin name.

    Only ``'range'`` is supported; it maps to ``tilelang.language.serial``.

    Raises
    ------
    ValueError
        If ``name`` is not a known override.
    """
    # Imported lazily inside the method (presumably to avoid a circular
    # import with tilelang.language -- TODO confirm).
    from tilelang.language import serial
    if name != 'range':
        raise ValueError(f'Unknown override: {name}')
    return serial


Expand Down
Loading