Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions testing/python/language/test_tilelang_language_frontend_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,5 +273,37 @@ def foo() -> T.Tensor((128,), T.float32):
assert isinstance(foo, T.PrimFunc)


def test_serial_for_with_step():

@tilelang.jit(out_idx=-1)
@T.prim_func
def test_stepped_serial(A: T.Tensor((10,), T.int32)):
with T.Kernel(1) as _:
for i in range(0, 10, 2):
T.device_assert(0 <= i < 10 and i % 2 == 0, "i out of range")
A[i] = 1.0
for i in range(1, 10, 2):
T.device_assert(1 <= i < 10 and i % 2 == 1, "i out of range")
A[i] = 2.0

ker = test_stepped_serial()
res = ker()
ref = torch.tensor([1, 2, 1, 2, 1, 2, 1, 2, 1, 2], dtype=torch.int32, device='cuda')
assert torch.all(res == ref), f"Expected {ref}, but got {res}"

@tilelang.jit(out_idx=-1)
@T.prim_func
def test_serial_step_neg(A: T.Tensor((10,), T.int32)):
with T.Kernel(1) as _:
for i in range(10, 0, -1):
T.device_assert(0 < i <= 10, "i out of range")
A[10 - i] = i

ker = test_serial_step_neg()
res = ker()
ref = torch.tensor([10, 9, 8, 7, 6, 5, 4, 3, 2, 1], dtype=torch.int32, device='cuda')
assert torch.all(res == ref), f"Expected {ref}, but got {res}"


if __name__ == '__main__':
tilelang.testing.main()
4 changes: 1 addition & 3 deletions tilelang/language/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@
SharedBuffer, # noqa: F401
LocalBuffer, # noqa: F401
)
from .parallel import Parallel # noqa: F401
from .pipeline import Pipelined # noqa: F401
from .persistent import Persistent # noqa: F401
from .loop import serial, Parallel, Persistent, Pipelined # noqa: F401
from .frame import has_let_value, get_let_value # noqa: F401
from .math_intrinsics import * # noqa: F401
from .kernel import (
Expand Down
108 changes: 108 additions & 0 deletions tilelang/language/loop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""The language interface for tl programs."""
from __future__ import annotations

from typing import Any
from tvm import tir
from tvm.tir import IntImm
import tvm.script.ir_builder.tir as tb_tir
from .v2.builder import SerialForWithStep
from tilelang import _ffi_api


def Parallel(*extents: tir.PrimExpr, coalesced_width: int | None = None):
"""Tools to construct nested parallel for loop.
This can be used to create element-wise tensor expression.
Parameters
----------
extents : PrimExpr
The extents of the iteration.
coalesced_width : Optional[int]
The coalesced width of the parallel loop.
Returns
-------
res : frame.ForFrame
The ForFrame.
"""
annotations: dict[str, Any] = {}
if coalesced_width is not None:
annotations.update({"coalesced_width": coalesced_width})
return _ffi_api.Parallel(extents, annotations) # type: ignore[attr-defined] # pylint: disable=no-member


def Persistent(
domain: list[tir.PrimExpr],
wave_size: tir.PrimExpr,
index: tir.PrimExpr,
group_size: tir.PrimExpr | None = 8,
):
"""Tools to construct persistent for loop.
Parameters
----------
domain : List[tir.PrimExpr]
The list of dominators.
wave_size : int
The wave size.
index : int
The tile index in one wave.
group_size : tir.PrimExpr
The group size.
"""
return _ffi_api.Persistent(domain, wave_size, index, group_size)


def Pipelined(
start: tir.PrimExpr,
stop: tir.PrimExpr = None,
num_stages: int = 0,
order: list[int] | None = None,
stage: list[int] | None = None,
sync: list[list[int]] | None = None,
group: list[list[int]] | None = None,
):
"""Tools to construct pipelined for loop.
Parameters
----------
start : PrimExpr
The minimum value of iteration.
stop : PrimExpr
The maximum value of iteration.
num_stages : int
The max number of buffer used between pipeline producers and consumers.
if num_stages is 0, pipeline will not be enabled.
Returns
-------
res : frame.ForFrame
The ForFrame.
"""
if stop is None:
stop = start
start = IntImm(start.dtype, 0) if hasattr(start, "dtype") else 0
if order is None:
order = []
if stage is None:
stage = []
if sync is None:
sync = []
if group is None:
group = []
# type: ignore[attr-defined] # pylint: disable=no-member
return _ffi_api.Pipelined(start, stop, num_stages, order, stage, sync, group)


def serial(start: tir.PrimExpr,
stop: tir.PrimExpr | None = None,
step: tir.PrimExpr | None = None,
*,
annotations: dict[str, Any] | None = None):
if step is None or step == 1:
return tb_tir.serial(start, stop, annotations=annotations)
else:
if stop is None:
stop = start
start = IntImm(start.dtype, 0) if hasattr(start, "dtype") else 0
return SerialForWithStep(start, stop, step, annotations=annotations)
29 changes: 0 additions & 29 deletions tilelang/language/parallel.py

This file was deleted.

27 changes: 0 additions & 27 deletions tilelang/language/persistent.py

This file was deleted.

46 changes: 0 additions & 46 deletions tilelang/language/pipeline.py

This file was deleted.

38 changes: 31 additions & 7 deletions tilelang/language/v2/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,14 @@ class BreakFrame(Frame):
...


@dataclass
class SerialForWithStep:
start: PrimExpr
stop: PrimExpr
step: PrimExpr
annotations: dict[str, Any] | None = None


ContinueOrBreak = ContinueFrame | BreakFrame
AnyFrame = tir.frame.IRBuilderFrame | Frame

Expand Down Expand Up @@ -236,12 +244,27 @@ def eval(self, val: Any):
def ctx_for(self, it):
self.check_continue_break()
it = unwrap_expr(it)
if not isinstance(it, tir.frame.ForFrame):
raise TypeError(
f"Invalid for loop, got {it}({type(it)}), expect one of the following: "
"range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding")
with self.with_frame(it) as v:
yield v
if isinstance(it, SerialForWithStep):
real_stop = tir.ceildiv(it.stop - it.start, it.step)
if isinstance(it.step, (int, IntImm)):
value = it.step if isinstance(it.step, int) else it.step.value
if value < 0:
real_stop = tir.ceildiv(it.start - it.stop, -it.step)
else:
logger.warning(
f'Using a non-constant step `{it.step}` in stepped serial may lead to undefined behavior in tilelang'
)
real_frame = tir.serial(real_stop, annotations=it.annotations)
with self.with_frame(real_frame) as v:
IRBuilder.name('_tmp', v)
yield it.start + v * it.step
else:
if not isinstance(it, tir.frame.ForFrame):
raise TypeError(
f"Invalid for loop, got {it}({type(it)}), expect one of the following: "
"range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding")
with self.with_frame(it) as v:
yield v

def ctx_continue(self):
self.check_continue_break()
Expand Down Expand Up @@ -449,8 +472,9 @@ def arg(self, name, value):
f"Unsupported argument type: {value}({type(value)}) for argument `{name}`.")

def override(self, name: str):
import tilelang.language as T
if name == 'range':
return tir.serial
return T.serial
raise ValueError(f'Unknown override: {name}')


Expand Down
Loading