Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions examples/gemm_sp/example_custom_compress.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,9 @@ def kernel(
for tm in T.Parallel(block_M):
for g_i in range(0, block_K // group):
a_k = g_i * group
T.clear(non_zero_cnt)
T.clear(non_zero_elt_log_idx)
non_zero_cnt[0] = 0
for i in range(elem):
non_zero_elt_log_idx[i] = 0
for i in range(group):
val = A_shared[tm, a_k + i]
if val != 0.0:
Expand Down
12 changes: 11 additions & 1 deletion src/target/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,12 @@ std::string CodeGenTileLangCUDA::Finish() {
void CodeGenTileLangCUDA::VisitStmt_(const tir::ForNode *op) {
if (op->kind == tir::ForKind::kUnrolled) {
PrintIndent();
stream << "#pragma unroll\n";
if (unroll_factor.count(op->loop_var.get())) {
stream << "#pragma unroll "
<< PrintExpr(unroll_factor[op->loop_var.get()]) << "\n";
} else {
stream << "#pragma unroll\n";
}
}
std::string extent =
PrintExpr(arith::Analyzer().Simplify(op->extent + op->min));
Expand Down Expand Up @@ -2661,7 +2666,12 @@ void CodeGenTileLangCUDA::VisitStmt_(const AttrStmtNode *op) {
this->stream << "const dim3 blockIdx = " << pattern->value << "();\n";
this->VisitStmt(op->body);
return;
} else if (op->attr_key == "pragma_unroll_factor") {
const IntImmNode *factor = op->value.as<IntImmNode>();
ICHECK(factor);
unroll_factor[op->node.as<VarNode>()] = Downcast<IntImm>(factor);
}

CodeGenC::VisitStmt_(op);
}

Expand Down
1 change: 1 addition & 0 deletions src/target/codegen_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ class CodeGenTileLangCUDA final : public CodeGenC {

std::unordered_map<const VarNode *, std::string> fragment_shapes;
std::unordered_map<const VarNode *, std::string> fragment_layouts;
std::unordered_map<const VarNode *, IntImm> unroll_factor;
friend void PrintConst(const FloatImmNode *op, std::ostream &os,
CodeGenTileLangCUDA *p);
Comment on lines +143 to 145
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Duplicate friend declaration detected.

The friend void PrintConst(...) declaration appears twice in this class (lines 81-82 and lines 144-145). The duplicate on lines 144-145 should be removed.

   std::unordered_map<const VarNode *, std::string> fragment_shapes;
   std::unordered_map<const VarNode *, std::string> fragment_layouts;
   std::unordered_map<const VarNode *, IntImm> unroll_factor;
-  friend void PrintConst(const FloatImmNode *op, std::ostream &os,
-                         CodeGenTileLangCUDA *p);
   void PrintWmmaScope(const std::string &scope, DataType t,
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
std::unordered_map<const VarNode *, IntImm> unroll_factor;
friend void PrintConst(const FloatImmNode *op, std::ostream &os,
CodeGenTileLangCUDA *p);
std::unordered_map<const VarNode *, IntImm> unroll_factor;
🤖 Prompt for AI Agents
In src/target/codegen_cuda.h around lines 143-145, there is a duplicate friend
declaration "friend void PrintConst(const FloatImmNode *op, std::ostream &os,
CodeGenTileLangCUDA *p);" that already exists earlier (lines 81-82); delete the
duplicate declaration at lines 144-145 so the class only contains the single
friend declaration at its original location.

void PrintWmmaScope(const std::string &scope, DataType t,
Expand Down
37 changes: 37 additions & 0 deletions testing/python/language/test_tilelang_language_unroll.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import tilelang.testing
from tilelang import tvm as tvm
from tilelang import language as T


def test_unroll_with_step():
    """Compile a kernel that uses ``T.unroll`` with ``step=4`` and check the
    generated source contains an unroll pragma."""

    @T.prim_func
    def main(A_ptr: T.handle):
        A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16)

        for _blockIdx in T.thread_binding(1, thread="blockIdx.x"):
            for _threadIdx in T.thread_binding(128, thread="threadIdx.x"):
                for i in T.unroll(0, 16, step=4):
                    A[0, i] = 1.0

    kernel = tilelang.compile(main, target="cuda")
    # Substring match only: a pragma with or without a trailing factor passes.
    assert "#pragma unroll" in kernel.get_kernel_source()


def test_unroll_with_unroll_factor():
    """Compile a kernel that uses ``T.unroll`` with ``unroll_factor=4`` and
    check the generated source contains ``#pragma unroll 4``."""

    @T.prim_func
    def main(A_ptr: T.handle):
        A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16)

        for _blockIdx in T.thread_binding(1, thread="blockIdx.x"):
            for _threadIdx in T.thread_binding(128, thread="threadIdx.x"):
                for i in T.unroll(0, 16, unroll_factor=4):
                    A[0, i] = 1.0

    kernel = tilelang.compile(main, target="cuda")
    # The requested factor must appear after the pragma in the emitted source.
    assert "#pragma unroll 4" in kernel.get_kernel_source()


if __name__ == "__main__":
tilelang.testing.main()
10 changes: 9 additions & 1 deletion tilelang/language/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,15 @@
LocalBuffer, # noqa: F401
Ref, # noqa: F401
)
from .loop import serial, Parallel, Persistent, Pipelined # noqa: F401
from .loop import (
Parallel, # noqa: F401
Persistent, # noqa: F401
Pipelined, # noqa: F401
serial, # noqa: F401
unroll, # noqa: F401
Serial, # noqa: F401
Unroll, # noqa: F401
)
from .frame import has_let_value, get_let_value # noqa: F401
from .math_intrinsics import * # noqa: F401
from .kernel import (
Expand Down
2 changes: 1 addition & 1 deletion tilelang/language/experimental/gemm_sp.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def legalize_arguments(arg: tir.Buffer | tir.Var):
C_arg = buffer_region_to_tile_region(C_region, "rw", [r for r in C_shape])
return tir.call_intrin(
"handle",
tir.op.Op.get("tl.gemm_sp_py"),
tir.op.Op.get("tl.tileop.gemm_sp_py"),
A_arg,
E_arg,
B_arg,
Expand Down
72 changes: 70 additions & 2 deletions tilelang/language/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
from tvm import tir
from tvm.tir import IntImm
import tvm.script.ir_builder.tir as tb_tir
from .v2.builder import SerialForWithStep
from .v2.builder import SerialForWithStep, UnrollForWithStep
from tilelang import _ffi_api
from tvm.script.ir_builder.tir import frame


def Parallel(*extents: tir.PrimExpr, coalesced_width: int | None = None):
Expand Down Expand Up @@ -97,7 +98,7 @@ def serial(start: tir.PrimExpr,
stop: tir.PrimExpr | None = None,
step: tir.PrimExpr | None = None,
*,
annotations: dict[str, Any] | None = None):
annotations: dict[str, Any] | None = None) -> frame.ForFrame:
step_is_one = False
step_is_one |= isinstance(step, int) and step == 1
step_is_one |= isinstance(step, IntImm) and step.value == 1
Expand All @@ -108,3 +109,70 @@ def serial(start: tir.PrimExpr,
stop = start
start = IntImm(start.dtype, 0) if hasattr(start, "dtype") else 0
return SerialForWithStep(start, stop, step, annotations=annotations)


def unroll(start: tir.PrimExpr,
           stop: tir.PrimExpr | None = None,
           step: tir.PrimExpr | None = None,
           *,
           explicit: bool = False,
           unroll_factor: int | None = None,
           annotations: dict[str, Any] | None = None) -> frame.ForFrame:
    """The unrolled For statement.

    Parameters
    ----------
    start : PrimExpr
        The minimum value of iteration. When ``stop`` is omitted, this value
        is used as the stop and the loop starts at 0.

    stop : PrimExpr
        The maximum value of iteration.

    step : PrimExpr
        The step size of the iteration. ``None`` and a constant 1 (``int`` or
        ``IntImm``) are both treated as a unit step.

    explicit : bool
        Whether to explicitly unroll the loop. Used as the default value of
        the ``pragma_unroll_explicit`` annotation.

    unroll_factor : int
        The unroll factor of the loop. Stored as the ``pragma_unroll_factor``
        annotation; only allowed when ``pragma_unroll_explicit`` is False.

    annotations : Dict[str, Any]
        The optional annotations of the For statement. The mapping is copied,
        so the caller's dict is never mutated.

    Returns
    -------
    res : frame.ForFrame
        The ForFrame.

    Raises
    ------
    ValueError
        If ``unroll_factor`` is given while ``pragma_unroll_explicit`` is True.
    """

    # Detect a unit step the same way `serial` does, so that step=1 is not
    # routed through the stepped-loop wrapper.
    step_is_one = False
    step_is_one |= isinstance(step, int) and step == 1
    step_is_one |= isinstance(step, IntImm) and step.value == 1

    if stop is None:
        # Single-argument form: unroll(n) iterates over [0, n).
        stop = start
        if hasattr(start, "dtype"):
            start = IntImm(start.dtype, 0)
        else:
            start = 0

    # Ensure annotations carries "pragma_unroll_explicit", defaulting to
    # `explicit`, without mutating the caller's dict.
    if annotations is None:
        annotations = {"pragma_unroll_explicit": explicit}
    else:
        annotations = dict(annotations)
        annotations.setdefault("pragma_unroll_explicit", explicit)

    if unroll_factor is not None:
        # A factor is only accepted when the unroll is not explicit.
        if annotations["pragma_unroll_explicit"]:
            raise ValueError("pragma_unroll_explicit must be False when unroll_factor is not None")
        annotations["pragma_unroll_factor"] = unroll_factor

Comment on lines +157 to +170
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Clarify pragma_unroll_explicit constraint and align error message/comment

The comment says pragma_unroll_explicit “must be False” when unroll_factor is set:

  • The code enforces exactly that (it raises when the flag is True), but
  • The error message says “must be True”, which contradicts both the comment and the actual condition; the long message also triggers Ruff's TRY003.

Consider tightening this block:

-    # Ensure annotations has {"pragma_unroll_explicit": True} by default
+    # Ensure annotations has {"pragma_unroll_explicit": explicit} by default
@@
-    if unroll_factor is not None:
-        # check pragma_unroll_explicit must be False
-        if annotations.get("pragma_unroll_explicit", True):
-            raise ValueError("pragma_unroll_explicit must be True when unroll_factor is not None")
-        annotations.update({"pragma_unroll_factor": unroll_factor})
+    if unroll_factor is not None:
+        # require non‑explicit unroll when using a factor
+        if annotations.get("pragma_unroll_explicit", False):
+            raise ValueError("unroll_factor requires pragma_unroll_explicit=False")
+        annotations["pragma_unroll_factor"] = unroll_factor

This makes the requirement obvious, fixes the contradictory message, and shortens the exception text to satisfy the TRY003 hint.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# Ensure annotations has {"pragma_unroll_explicit": True} by default
if annotations is None:
annotations = {"pragma_unroll_explicit": explicit}
else:
# Add "pragma_unroll_explicit": True if not already present
annotations = dict(annotations)
annotations.setdefault("pragma_unroll_explicit", explicit)
if unroll_factor is not None:
# check pragma_unroll_explicit must be False
if annotations.get("pragma_unroll_explicit", True):
raise ValueError("pragma_unroll_explicit must be True when unroll_factor is not None")
annotations.update({"pragma_unroll_factor": unroll_factor})
# Ensure annotations has {"pragma_unroll_explicit": explicit} by default
if annotations is None:
annotations = {"pragma_unroll_explicit": explicit}
else:
# Add "pragma_unroll_explicit": explicit if not already present
annotations = dict(annotations)
annotations.setdefault("pragma_unroll_explicit", explicit)
if unroll_factor is not None:
# require non‑explicit unroll when using a factor
if annotations.get("pragma_unroll_explicit", False):
raise ValueError("unroll_factor requires pragma_unroll_explicit=False")
annotations["pragma_unroll_factor"] = unroll_factor
🧰 Tools
🪛 Ruff (0.14.6)

168-168: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In tilelang/language/loop.py around lines 157 to 170, the inline comment and the
raised ValueError contradict each other: the code currently raises when
annotations.get("pragma_unroll_explicit", True) is True but the error text says
"must be True" and the comment wording is unclear; update the comment to state
clearly that pragma_unroll_explicit must be False when unroll_factor is
provided, and change the ValueError text to a short, correct message like
"pragma_unroll_explicit must be False" (keep it brief to satisfy TRY003).

if step is None or step_is_one:
return tb_tir.unroll(start, stop, annotations=annotations)
else:
return UnrollForWithStep(start, stop, step, annotations=annotations)

Comment on lines +149 to +175
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Fix step_is_one handling so step=1 uses the simple unroll path

Right now step_is_one is always False, so step=1 will go through UnrollForWithStep instead of the plain tb_tir.unroll path, unlike serial(...). This is almost certainly unintended and diverges from the serial helper’s behavior.

You can mirror the serial implementation by actually computing step_is_one:

-    step_is_one = False
-    if stop is None:
-        stop = start
-        if hasattr(start, "dtype"):
-            start = IntImm(start.dtype, 0)
-        else:
-            start = 0
-
-    # Ensure annotations has {"pragma_unroll_explicit": True} by default
+    if stop is None:
+        stop = start
+        if hasattr(start, "dtype"):
+            start = IntImm(start.dtype, 0)
+        else:
+            start = 0
+
+    step_is_one = False
+    step_is_one |= isinstance(step, int) and step == 1
+    step_is_one |= isinstance(step, IntImm) and getattr(step, "value", None) == 1
+
+    # Ensure annotations has {"pragma_unroll_explicit": explicit} by default
@@
-    if step is None or step_is_one:
-        return tb_tir.unroll(start, stop, annotations=annotations)
-    else:
-        return UnrollForWithStep(start, stop, step, annotations=annotations)
+    if step is None or step_is_one:
+        return tb_tir.unroll(start, stop, annotations=annotations)
+    else:
+        return UnrollForWithStep(start, stop, step, annotations=annotations)

This keeps step=None/step=1 on the simple path and only uses UnrollForWithStep when the step is genuinely non‑unit.

🧰 Tools
🪛 Ruff (0.14.6)

168-168: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In tilelang/language/loop.py around lines 149 to 175, the variable step_is_one
is left False so step=1 incorrectly bypasses the simple unroll path; compute
step_is_one based on the passed step (mirror serial: treat None or a unit
constant/IntImm of value 1 as "one") before the branch so that step=None or
step==1 use tb_tir.unroll, and only non‑unit steps go to UnrollForWithStep; also
ensure the dtype/IntImm case is handled consistently with how start was
normalized.


Serial = serial
Unroll = unroll
16 changes: 14 additions & 2 deletions tilelang/language/v2/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,11 @@ class SerialForWithStep:
annotations: dict[str, Any] | None = None


@dataclass
class UnrollForWithStep(SerialForWithStep):
    """Stepped loop descriptor that the builder lowers with ``tir.unroll``.

    Adds no fields of its own to ``SerialForWithStep``; the subclass type is
    what ``ctx_for`` checks to choose ``tir.unroll`` over ``tir.serial``.
    """
    ...


# Python 3.9 compatibility: avoid PEP 604 unions at runtime
# Use tuple for isinstance checks and typing.Union for annotations/aliases
ContinueOrBreak = (ContinueFrame, BreakFrame)
Expand Down Expand Up @@ -270,7 +275,7 @@ def eval(self, val: Any):
def ctx_for(self, it):
self.check_continue_break()
it = unwrap_expr(it)
if isinstance(it, SerialForWithStep):
if isinstance(it, (SerialForWithStep, UnrollForWithStep)):
# Validate and compute the trip count before constructing the frame
if isinstance(it.step, (int, IntImm)):
step_value = it.step if isinstance(it.step, int) else it.step.value
Expand All @@ -285,7 +290,14 @@ def ctx_for(self, it):
f'Using a non-constant step `{it.step}` in stepped serial may lead to undefined behavior in tilelang'
)
real_stop = tir.ceildiv(it.stop - it.start, it.step)
real_frame = tir.serial(real_stop, annotations=it.annotations)
if isinstance(it, UnrollForWithStep):
real_frame = tir.unroll(real_stop, annotations=it.annotations)
elif isinstance(it, SerialForWithStep):
real_frame = tir.serial(real_stop, annotations=it.annotations)
else:
raise TypeError(
f"Invalid for loop, got {it}({type(it)}), expect one of the following: "
"range, T.serial, T.unroll, T.grid, T.parallel, T.vectorized, T.thread_binding")
with self.with_frame(real_frame) as v:
IRBuilder.name('_tmp', v)
yield it.start + v * it.step
Expand Down
Loading