-
Notifications
You must be signed in to change notification settings - Fork 448
[Enhancement] Implement dynamic unroll factor in CUDA code generation #1360
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b42ec64
d1330f7
fbc927f
9a1a674
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,37 @@ | ||
| import tilelang.testing | ||
| from tilelang import tvm as tvm | ||
| from tilelang import language as T | ||
|
|
||
|
|
||
| def test_unroll_with_step(): | ||
|
|
||
| @T.prim_func | ||
| def main(A_ptr: T.handle): | ||
| A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16) | ||
|
|
||
| for _blockIdx in T.thread_binding(1, thread="blockIdx.x"): | ||
| for _threadIdx in T.thread_binding(128, thread="threadIdx.x"): | ||
| for i in T.unroll(0, 16, step=4): | ||
| A[0, i] = 1.0 | ||
|
|
||
| kernel = tilelang.compile(main, target="cuda") | ||
| assert "#pragma unroll" in kernel.get_kernel_source() | ||
|
|
||
|
|
||
| def test_unroll_with_unroll_factor(): | ||
|
|
||
| @T.prim_func | ||
| def main(A_ptr: T.handle): | ||
| A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16) | ||
|
|
||
| for _blockIdx in T.thread_binding(1, thread="blockIdx.x"): | ||
| for _threadIdx in T.thread_binding(128, thread="threadIdx.x"): | ||
| for i in T.unroll(0, 16, unroll_factor=4): | ||
| A[0, i] = 1.0 | ||
|
|
||
| kernel = tilelang.compile(main, target="cuda") | ||
| assert "#pragma unroll 4" in kernel.get_kernel_source() | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| tilelang.testing.main() |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -4,8 +4,9 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||
| from tvm import tir | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from tvm.tir import IntImm | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import tvm.script.ir_builder.tir as tb_tir | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from .v2.builder import SerialForWithStep | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from .v2.builder import SerialForWithStep, UnrollForWithStep | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from tilelang import _ffi_api | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from tvm.script.ir_builder.tir import frame | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def Parallel(*extents: tir.PrimExpr, coalesced_width: int | None = None): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -97,7 +98,7 @@ def serial(start: tir.PrimExpr, | |||||||||||||||||||||||||||||||||||||||||||||||||||||
| stop: tir.PrimExpr | None = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| step: tir.PrimExpr | None = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| *, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| annotations: dict[str, Any] | None = None): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| annotations: dict[str, Any] | None = None) -> frame.ForFrame: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| step_is_one = False | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| step_is_one |= isinstance(step, int) and step == 1 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| step_is_one |= isinstance(step, IntImm) and step.value == 1 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -108,3 +109,70 @@ def serial(start: tir.PrimExpr, | |||||||||||||||||||||||||||||||||||||||||||||||||||||
| stop = start | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| start = IntImm(start.dtype, 0) if hasattr(start, "dtype") else 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return SerialForWithStep(start, stop, step, annotations=annotations) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def unroll(start: tir.PrimExpr, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| stop: tir.PrimExpr | None = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| step: tir.PrimExpr | None = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| *, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| explicit: bool = False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| unroll_factor: int | None = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| annotations: dict[str, Any] | None = None) -> frame.ForFrame: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """The unrolled For statement. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Parameters | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ---------- | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| start : PrimExpr | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| The minimum value of iteration. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| stop : PrimExpr | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| The maximum value of iteration. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| step : PrimExpr | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| The step size of the iteration. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| explicit : bool | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Whether to explicitly unroll the loop. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| unroll_factor : int | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| The unroll factor of the loop. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| annotations : Dict[str, Any] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| The optional annotations of the For statement. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Returns | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ------- | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| res : frame.ForFrame | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| The ForFrame. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| step_is_one = False | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if stop is None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| stop = start | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if hasattr(start, "dtype"): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| start = IntImm(start.dtype, 0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| start = 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Ensure annotations has {"pragma_unroll_explicit": True} by default | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if annotations is None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| annotations = {"pragma_unroll_explicit": explicit} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Add "pragma_unroll_explicit": True if not already present | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| annotations = dict(annotations) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| annotations.setdefault("pragma_unroll_explicit", explicit) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if unroll_factor is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # check pragma_unroll_explicit must be False | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if annotations.get("pragma_unroll_explicit", True): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError("pragma_unroll_explicit must be True when unroll_factor is not None") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| annotations.update({"pragma_unroll_factor": unroll_factor}) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+157
to
+170
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Clarify The comment says
Consider tightening this block: - # Ensure annotations has {"pragma_unroll_explicit": True} by default
+ # Ensure annotations has {"pragma_unroll_explicit": explicit} by default
@@
- if unroll_factor is not None:
- # check pragma_unroll_explicit must be False
- if annotations.get("pragma_unroll_explicit", True):
- raise ValueError("pragma_unroll_explicit must be True when unroll_factor is not None")
- annotations.update({"pragma_unroll_factor": unroll_factor})
+ if unroll_factor is not None:
+ # require non‑explicit unroll when using a factor
+ if annotations.get("pragma_unroll_explicit", False):
+ raise ValueError("unroll_factor requires pragma_unroll_explicit=False")
+ annotations["pragma_unroll_factor"] = unroll_factorThis makes the requirement obvious, fixes the contradictory message, and shortens the exception text to satisfy the TRY003 hint. 📝 Committable suggestion
Suggested change
🧰 Tools🪛 Ruff (0.14.6)168-168: Avoid specifying long messages outside the exception class (TRY003) 🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if step is None or step_is_one: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return tb_tir.unroll(start, stop, annotations=annotations) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return UnrollForWithStep(start, stop, step, annotations=annotations) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+149
to
+175
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix Right now You can mirror the - step_is_one = False
- if stop is None:
- stop = start
- if hasattr(start, "dtype"):
- start = IntImm(start.dtype, 0)
- else:
- start = 0
-
- # Ensure annotations has {"pragma_unroll_explicit": True} by default
+ if stop is None:
+ stop = start
+ if hasattr(start, "dtype"):
+ start = IntImm(start.dtype, 0)
+ else:
+ start = 0
+
+ step_is_one = False
+ step_is_one |= isinstance(step, int) and step == 1
+ step_is_one |= isinstance(step, IntImm) and getattr(step, "value", None) == 1
+
+ # Ensure annotations has {"pragma_unroll_explicit": explicit} by default
@@
- if step is None or step_is_one:
- return tb_tir.unroll(start, stop, annotations=annotations)
- else:
- return UnrollForWithStep(start, stop, step, annotations=annotations)
+ if step is None or step_is_one:
+ return tb_tir.unroll(start, stop, annotations=annotations)
+ else:
+ return UnrollForWithStep(start, stop, step, annotations=annotations)This keeps 🧰 Tools🪛 Ruff (0.14.6)168-168: Avoid specifying long messages outside the exception class (TRY003) 🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Serial = serial | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Unroll = unroll | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Duplicate friend declaration detected.
The
friend void PrintConst(...)declaration appears twice in this class (lines 81-82 and lines 144-145). The duplicate on lines 144-145 should be removed.📝 Committable suggestion
🤖 Prompt for AI Agents