24 changes: 24 additions & 0 deletions testing/python/issue/test_tilelang_issue_1601.py
@@ -0,0 +1,24 @@
import tilelang
import tilelang.testing
import tilelang.language as T


def test_issue_1601():
    @tilelang.jit
    def qwq():
        @T.prim_func
        def main(
                A: T.Tensor((8,), T.float8_e4m3fn),
        ):
            with T.Kernel(1, threads=32):
                for i in T.vectorized(8):
                    A[i] = 0

        return main

    kernel = qwq()
    # The fp8 zero should be hoisted into a scalar `broadcast_var` before the vectorized store.
    assert "fp8_e4_t broadcast_var = fp8_e4_t(0x0p+0f/*0.000000e+00*/);" in kernel.get_kernel_source()


if __name__ == "__main__":
    tilelang.testing.main()
@@ -0,0 +1,86 @@
import tilelang
import tilelang.language as T
import torch
import re
import pytest
import tilelang.testing
from tilelang import tvm as tvm
import tilelang as tl
from tilelang.utils.target import determine_target


@tilelang.jit
def qwq(dtype=torch.float8_e4m3fn):
    @T.prim_func
    def main(
            A: T.Tensor((32,), dtype),
            B: T.Tensor((16,), dtype),
            C: T.Tensor((8,), dtype),
            D: T.Tensor((4,), dtype),
            E: T.Tensor((2,), dtype),
    ):
        with T.Kernel(1, threads=32):
            var = T.alloc_var(dtype, 1.0)
            for i in T.vectorized(32):
                A[i] = var
            for i in T.vectorized(16):
                B[i] = 13.5
            for i in T.vectorized(8):
                C[i] = 3.14
            for i in T.vectorized(4):
                D[i] = 2.72
            for i in T.vectorized(2):
                E[i] = 430

    return main


@tilelang.testing.requires_cuda
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e8m0fnu, torch.float16])
def test_hoist_broadcast(dtype):
    kernel = qwq(dtype)
    print(kernel.get_kernel_source())
    # Each hoisted constant appears in the source as "<type> broadcast_var[_N] = <type>(...)".
    # The store of `var` into A already reads a scalar variable, so only the four
    # immediates (13.5, 3.14, 2.72, 430) are hoisted.
    matches = re.findall(r"(\w+) broadcast_var(_[0-9]+)? = \1", kernel.get_kernel_source())
    assert len(matches) == 4
    a = torch.empty((32,), device="cuda", dtype=dtype)
    b = torch.empty((16,), device="cuda", dtype=dtype)
    c = torch.empty((8,), device="cuda", dtype=dtype)
    d = torch.empty((4,), device="cuda", dtype=dtype)
    e = torch.empty((2,), device="cuda", dtype=dtype)
    kernel(a, b, c, d, e)


auto_target = tvm.target.Target(determine_target("auto"))


def _check(original, transformed):
    mod = tvm.IRModule.from_expr(original.with_attr("global_symbol", "main"))
    mod = tvm.tir.transform.BindTarget(auto_target)(mod)
    mod = tl.transform.HoistBroadcastValues()(mod)

    transformed = tvm.IRModule.from_expr(transformed.with_attr("global_symbol", "main"))
    transformed = tvm.tir.transform.BindTarget(auto_target)(transformed)

    tvm.ir.assert_structural_equal(mod["main"], transformed["main"], True)


def test_transform_hoist():
    @T.prim_func
    def before():
        with T.Kernel(8):
            A_shared = T.decl_buffer((256,), T.float8_e4m3fn, scope="shared.dyn")
            A_shared[0:8] = T.Broadcast(T.float8_e4m3fn(1.2), 8) + T.Broadcast(T.float8_e4m3fn(3.4), 8)

    @T.prim_func
    def after():
        with T.Kernel(8):
            A_shared = T.decl_buffer((256,), T.float8_e4m3fn, scope="shared.dyn")
            broadcast_var: T.float8_e4m3fn = T.float8_e4m3fn(1.2)
            broadcast_var_1: T.float8_e4m3fn = T.float8_e4m3fn(3.4)
            A_shared[0:8] = T.Broadcast(broadcast_var, 8) + T.Broadcast(broadcast_var_1, 8)

    _check(before, after)


if __name__ == "__main__":
    tilelang.testing.main()
3 changes: 3 additions & 0 deletions tilelang/engine/lower.py
@@ -170,6 +170,7 @@ def device_codegen(device_mod: tvm.IRModule, target: Target) -> tvm.IRModule:
    device_mod = tilelang.transform.LowerDeviceStorageAccessInfo()(device_mod)
    device_mod = tilelang.transform.LowerIntrin()(device_mod)
    device_mod = tir.transform.Simplify()(device_mod)
    device_mod = tilelang.transform.HoistBroadcastValues()(device_mod)

    if target.kind.name == "cuda":
        global_func = "target.build.tilelang_" + ("cutedsl" if "cutedsl" in target.keys else "cuda")
@@ -186,6 +187,8 @@ def device_codegen_without_compile(device_mod: tvm.IRModule, target: Target) ->
    device_mod = tilelang.transform.LowerDeviceStorageAccessInfo()(device_mod)
    device_mod = tilelang.transform.LowerIntrin()(device_mod)
    device_mod = tir.transform.Simplify()(device_mod)
    device_mod = tilelang.transform.HoistBroadcastValues()(device_mod)

    if target.kind.name == "cuda":
        global_func = "target.build.tilelang_" + ("cutedsl" if "cutedsl" in target.keys else "cuda") + "_without_compile"
        device_mod = tvm.ffi.get_global_func(global_func)(device_mod, target)
1 change: 1 addition & 0 deletions tilelang/transform/__init__.py
@@ -7,6 +7,7 @@
from tilelang import tvm as tvm # noqa: F401
from tvm.ir.transform import PassContext # noqa: F401
from .add_bufstore_wrapper import AddWrapperForSingleBufStore # noqa: F401
from .hoist_broadcast_values import HoistBroadcastValues # noqa: F401


def get_pass_context():
83 changes: 83 additions & 0 deletions tilelang/transform/hoist_broadcast_values.py
@@ -0,0 +1,83 @@
from tvm import tir
from tvm.tir import (
    BufferStore,
    LetStmt,
    Broadcast,
    Var,
    PrimFunc,
    PyStmtExprMutator,
)
from tvm.tir.transform import prim_func_pass


@tir.functor.mutator
class HoistBroadcastValuesMutator(PyStmtExprMutator):
    def __init__(self):
        super().__init__()
        # Temporary queue: holds (variable, value) pairs that must be defined
        # around the statement currently being visited.
        self.pending_defs = []

    def visit_broadcast_(self, op):
        if isinstance(op.value, (tir.IntImm, tir.FloatImm)):
            # 1. Intercept Broadcast nodes whose value is an immediate constant.
            val = self.visit_expr(op.value)
            # 2. Create a fresh scalar variable to hold the constant.
            new_var = Var("broadcast_var", dtype=val.dtype)

            # 3. Queue the (variable, value) pair.
            # Note: do not create the LetStmt here; it must wrap the enclosing statement.
            self.pending_defs.append((new_var, val))

            # 4. Return a new Broadcast that reads the variable instead of the constant.
            return Broadcast(new_var, op.lanes)
        return Broadcast(self.visit_expr(op.value), self.visit_expr(op.lanes))

    # Every statement kind that can contain expressions would need this treatment
    # (e.g. BufferStore, LetStmt, Evaluate, IfThenElse, AssertStmt); only
    # BufferStore is intercepted here.
    def visit_buffer_store_(self, op: BufferStore):
        # 1. Clear the pending queue for the current statement context.
        self.pending_defs = []

        # 2. Visit child nodes normally (this triggers visit_broadcast_).
        new_indices = [self.visit_expr(idx) for idx in op.indices]
        new_stmt = BufferStore(op.buffer, self.visit_expr(op.value), new_indices)

        # 3. Check whether any variables are waiting to be defined.
        if self.pending_defs:
            # 4. Wrap the current statement in LetStmts, iterating in reverse so
            # the first definition becomes the outermost binding:
            #   Let broadcast_var = val In BufferStore(...)
            for var, val in reversed(self.pending_defs):
                new_stmt = LetStmt(var, val, new_stmt)

        # Clear the queue for the next statement.
        self.pending_defs = []
        return new_stmt


def HoistBroadcastValues():
    """
    TVM pass: HoistBroadcastValues.

    This pass scans the TIR for Broadcast operations whose value is an immediate
    constant (IntImm, FloatImm) and hoists each constant into a scalar variable
    bound by a LetStmt that immediately wraps the statement containing the broadcast.

    Example transformation:
    -----------------------
    Before:
        A[i] = B[i] + T.Broadcast(3.14, 4) + T.Broadcast(3.14, 4)

    After:
        broadcast_var = 3.14
        broadcast_var_1 = 3.14
        A[i] = B[i] + T.Broadcast(broadcast_var, 4) + T.Broadcast(broadcast_var_1, 4)
    """

    def pass_fn(func: PrimFunc, mod, ctx):
        mutator = HoistBroadcastValuesMutator()
        new_body = mutator.visit_stmt(func.body)
        return func.with_body(new_body)

    return prim_func_pass(pass_fn, opt_level=0)
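
For reference, a minimal standalone usage sketch of the new pass, mirroring the _check helper in the tests above. The names demo and buf are illustrative, not from this PR; the tests additionally bind a target via tvm.tir.transform.BindTarget before running the pass, which is omitted here for brevity.

from tilelang import tvm as tvm
import tilelang as tl
import tilelang.language as T

@T.prim_func
def demo():
    with T.Kernel(8):
        buf = T.decl_buffer((256,), T.float8_e4m3fn, scope="shared.dyn")
        buf[0:8] = T.Broadcast(T.float8_e4m3fn(1.2), 8)

# Run the pass over a module built from the PrimFunc.
mod = tvm.IRModule.from_expr(demo.with_attr("global_symbol", "main"))
mod = tl.transform.HoistBroadcastValues()(mod)
print(mod["main"])
# Expected shape of the result (cf. "after" in test_transform_hoist):
#   broadcast_var: T.float8_e4m3fn = T.float8_e4m3fn(1.2)
#   buf[0:8] = T.Broadcast(broadcast_var, 8)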