Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions testing/python/language/test_tilelang_language_warp_sync.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import tilelang
import tilelang.language as T
import torch
from tvm import tir
import tilelang.testing


@tilelang.jit
def kernel_with_warp_sync():
    """Build a kernel where thread 1 reads, after ``T.sync_warp()``, a value thread 0 stored.

    Returns the ``T.prim_func`` ``main(A, B)``; both arguments are
    one-element ``int32`` tensors.
    """

    @T.prim_func
    def main(
        A: T.Tensor((1,), "int32"),
        B: T.Tensor((1,), "int32"),
    ):
        # One block of 32 threads.
        with T.Kernel(1, threads=32):
            tx = T.get_thread_binding()
            if tx == 0:
                # Thread 0 calls __nanosleep(100) before its store --
                # presumably to delay the write so that a missing warp
                # sync would be observable by thread 1 (TODO confirm).
                tir.call_extern("void", "__nanosleep", 100)
                A[0] = -1
            # Every thread reaches the warp sync between the write and the read.
            T.sync_warp()
            if tx == 1:
                # Thread 1 copies what thread 0 stored into B.
                B[0] = A[0]

    return main


@tilelang.testing.requires_cuda
def test_warp_sync():
    """Check that ``T.sync_warp()`` lowers to ``__syncwarp`` and orders the write before the read.

    The kernel stores -1 into ``A[0]`` from thread 0 and, after the warp
    sync, copies ``A[0]`` into ``B[0]`` from thread 1, so ``b[0]`` must be -1.
    """
    # Zero-initialise both tensors: uninitialised memory could already hold
    # -1 and let the final assertion pass even if the kernel wrote nothing.
    a = torch.zeros((1,), device="cuda", dtype=torch.int32)
    b = torch.zeros((1,), device="cuda", dtype=torch.int32)
    kernel = kernel_with_warp_sync()
    assert "__syncwarp" in kernel.get_kernel_source()
    kernel(a, b)
    assert b[0] == -1


@tilelang.jit
def kernel_with_shfl_sync():
    """Build a kernel in which every thread stores the value shuffled from lane 31.

    Returns the ``T.prim_func`` ``main(A)``; ``A`` is a 32-element
    ``int32`` tensor.
    """

    @T.prim_func
    def main(
        A: T.Tensor((32,), "int32"),
    ):
        # One block of 32 threads.
        with T.Kernel(1, threads=32):
            tx = T.get_thread_binding()
            # Each thread contributes its own index times ten.
            val = tx * 10
            # Full mask, source lane 31: the value requested is lane 31's `val`.
            broadcast = T.shfl_sync(0xFFFFFFFF, val, 31)
            A[tx] = broadcast

    return main


@tilelang.testing.requires_cuda
def test_shfl_sync():
    """Check that ``T.shfl_sync`` lowers to ``__shfl_sync`` and broadcasts lane 31's value.

    Each thread computes ``tx * 10`` and reads the value of lane 31, so every
    element of the output must equal 310.
    """
    a = torch.empty((32,), device="cuda", dtype=torch.int32)
    kernel = kernel_with_shfl_sync()
    assert "__shfl_sync" in kernel.get_kernel_source()
    kernel(a)
    assert torch.all(a == 310)


# When executed as a script, hand control to tilelang's test entry point.
if __name__ == "__main__":
    tilelang.testing.main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import tilelang
import tilelang.language as T
import torch
import re
import pytest
import tilelang.testing


@tilelang.jit
def qwq(dtype=torch.float8_e4m3fn):
    """Build a kernel that fills five tensors through vectorized loops of width 32, 16, 8, 4 and 2.

    ``A`` is filled from a variable allocated with ``T.alloc_var``; ``B``
    through ``E`` are filled with literal constants.

    Args:
        dtype: element type of all five tensors and of the allocated variable.

    Returns:
        The ``T.prim_func`` ``main(A, B, C, D, E)``.
    """

    @T.prim_func
    def main(
        A: T.Tensor((32,), dtype),
        B: T.Tensor((16,), dtype),
        C: T.Tensor((8,), dtype),
        D: T.Tensor((4,), dtype),
        E: T.Tensor((2,), dtype),
    ):
        # One block of 32 threads.
        with T.Kernel(1, threads=32):
            # Variable of the tensor dtype, allocated with 1.0 as its second argument.
            var = T.alloc_var(dtype, 1.0)
            # Store sourced from a variable rather than a literal.
            for i in T.vectorized(32):
                A[i] = var
            # The remaining four stores each write a literal constant.
            for i in T.vectorized(16):
                B[i] = 13.5
            for i in T.vectorized(8):
                C[i] = 3.14
            for i in T.vectorized(4):
                D[i] = 2.72
            for i in T.vectorized(2):
                E[i] = 430

    return main


@tilelang.testing.requires_cuda
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e8m0fnu, torch.float16])
def test_broadcast(dtype):
    """Check that the generated source declares four ``broadcast_var`` temporaries and that the kernel runs.

    The regex looks for declarations of the form ``<type> broadcast_var[_N] = <type>...``,
    one per constant-valued vectorized store in ``qwq``.
    """
    kernel = qwq(dtype)
    source = kernel.get_kernel_source()
    print(source)

    hoisted = re.findall(r"(\w+) broadcast_var(_[0-9]+)? = \1", source)
    assert len(hoisted) == 4

    # One output tensor per kernel argument, sized to match the kernel signature.
    outputs = [torch.empty((length,), device="cuda", dtype=dtype) for length in (32, 16, 8, 4, 2)]
    kernel(*outputs)


# When executed as a script, hand control to tilelang's test entry point.
if __name__ == "__main__":
    tilelang.testing.main()
3 changes: 3 additions & 0 deletions tilelang/engine/lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def device_codegen(device_mod: tvm.IRModule, target: Target) -> tvm.IRModule:
device_mod = tilelang.transform.LowerDeviceStorageAccessInfo()(device_mod)
device_mod = tilelang.transform.LowerIntrin()(device_mod)
device_mod = tir.transform.Simplify()(device_mod)
device_mod = tilelang.transform.HoistBroadcastValues()(device_mod)

if target.kind.name == "cuda":
global_func = "target.build.tilelang_" + ("cutedsl" if "cutedsl" in target.keys else "cuda")
Expand All @@ -208,6 +209,8 @@ def device_codegen_without_compile(device_mod: tvm.IRModule, target: Target) ->
device_mod = tilelang.transform.LowerDeviceStorageAccessInfo()(device_mod)
device_mod = tilelang.transform.LowerIntrin()(device_mod)
device_mod = tir.transform.Simplify()(device_mod)
device_mod = tilelang.transform.HoistBroadcastValues()(device_mod)

if target.kind.name == "cuda":
global_func = "target.build.tilelang_" + ("cutedsl" if "cutedsl" in target.keys else "cuda") + "_without_compile"
device_mod = tvm.ffi.get_global_func(global_func)(device_mod, target)
Expand Down
14 changes: 14 additions & 0 deletions tilelang/language/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,20 @@ def sync_threads(barrier_id: int = None, arrive_count: int = None):
return tir.call_intrin("int32", "tir.tvm_storage_sync", "shared", *args)


def sync_warp(mask: int = None):
    """Emit an extern call to ``__syncwarp``.

    Args:
        mask: optional argument forwarded to ``__syncwarp``; when ``None``
            the call is emitted with no arguments.

    Returns:
        The ``void`` extern-call expression.
    """
    extra_args = () if mask is None else (mask,)
    return tir.call_extern("void", "__syncwarp", *extra_args)


def shfl_sync(mask: int, value: int | PrimExpr, srcLane: int, width: int = None):
    """Emit an extern call to ``__shfl_sync`` that reads ``value`` from another lane.

    Args:
        mask: lane mask forwarded as the first argument.
        value: the value to exchange; a TIR expression or a plain Python scalar.
        srcLane: lane whose ``value`` is requested.
        width: optional fourth argument; omitted from the call when ``None``.

    Returns:
        An extern-call expression whose dtype is the dtype of ``value``.
    """
    if isinstance(value, (int, float)):
        # Plain Python scalars carry no ``dtype`` attribute; wrap them in a
        # TIR constant (dtype inferred by ``tir.const``) so the return dtype
        # of the call can be taken from the value.
        value = tir.const(value)
    call_args = [mask, value, srcLane]
    if width is not None:
        call_args.append(width)
    return tir.call_extern(value.dtype, "__shfl_sync", *call_args)


def sync_global():
"""Synchronize all threads in the entire grid."""
tx, ty, tz = get_thread_bindings()
Expand Down
1 change: 1 addition & 0 deletions tilelang/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from tilelang import tvm as tvm # noqa: F401
from tvm.ir.transform import PassContext # noqa: F401
from .add_bufstore_wrapper import AddWrapperForSingleBufStore # noqa: F401
from .hoist_broadcast_values import HoistBroadcastValues # noqa: F401


def get_pass_context():
Expand Down
83 changes: 83 additions & 0 deletions tilelang/transform/hoist_broadcast_values.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from tvm import tir
from tvm.tir import (
BufferStore,
LetStmt,
Broadcast,
Var,
PrimFunc,
PyStmtExprMutator,
)
from tvm.tir.transform import prim_func_pass


@tir.functor.mutator
class HoistBroadcastValuesMutator(PyStmtExprMutator):
    """Replace constant broadcasts inside buffer stores with broadcasts of a let-bound variable.

    For every ``Broadcast`` whose value is an ``IntImm`` or ``FloatImm`` and
    that occurs in the indices or value of a ``BufferStore``, a fresh variable
    named ``broadcast_var`` is created, the broadcast is rewritten to use it,
    and the store is wrapped in a ``LetStmt`` binding the variable to the
    constant. Broadcasts elsewhere are left as they are, because only
    ``visit_buffer_store_`` emits the bindings; hoisting them would leave a
    variable without a definition.
    """

    def __init__(self):
        super().__init__()
        # (variable, value) pairs created while visiting the current BufferStore.
        self.pending_defs = []
        # True only while the operands of a BufferStore are being visited,
        # i.e. while there is an enclosing statement that will bind the variables.
        self._in_buffer_store = False

    def visit_broadcast_(self, op):
        """Hoist a constant broadcast value into a pending variable when inside a BufferStore."""
        hoistable = self._in_buffer_store and isinstance(op.value, (tir.IntImm, tir.FloatImm))
        if not hoistable:
            # No statement will bind a new variable here (or the value is not
            # an immediate): just rebuild the node from its visited children.
            return Broadcast(self.visit_expr(op.value), self.visit_expr(op.lanes))

        val = self.visit_expr(op.value)
        new_var = Var("broadcast_var", dtype=val.dtype)
        # The LetStmt is not created here; it has to wrap the whole store,
        # so the binding is queued for visit_buffer_store_.
        self.pending_defs.append((new_var, val))
        return Broadcast(new_var, op.lanes)

    def visit_buffer_store_(self, op: BufferStore):
        """Rewrite a store and wrap it in one LetStmt per broadcast constant hoisted from it."""
        # Start from an empty queue for this statement.
        self.pending_defs = []
        self._in_buffer_store = True
        try:
            # Indices are visited before the value, as in the original traversal.
            new_indices = [self.visit_expr(idx) for idx in op.indices]
            new_value = self.visit_expr(op.value)
        finally:
            self._in_buffer_store = False

        # NOTE(review): only buffer, value and indices are forwarded; if the
        # TVM version in use attaches a predicate or span to BufferStore they
        # are not carried over -- confirm this is acceptable.
        new_stmt = BufferStore(op.buffer, new_value, new_indices)

        # Wrap in reverse so the first queued definition ends up outermost:
        #   Let v0 = c0 In (Let v1 = c1 In BufferStore(...))
        for var, val in reversed(self.pending_defs):
            new_stmt = LetStmt(var, val, new_stmt)

        # Leave the queue empty for the next statement.
        self.pending_defs = []
        return new_stmt


def HoistBroadcastValues():
    """Create the HoistBroadcastValues PrimFunc pass.

    The pass runs ``HoistBroadcastValuesMutator`` over each function body.
    Broadcasts of immediate constants (``IntImm`` / ``FloatImm``) inside a
    buffer store are rewritten to broadcast a variable, and that variable is
    bound to the constant by a ``LetStmt`` wrapped around the store. Every
    occurrence gets its own variable, created with the name ``broadcast_var``;
    equal constants are not shared.

    Example
    -------
    Before::

        A[i] = B[i] + T.Broadcast(3.14, 4) + T.Broadcast(3.14, 4)

    After (the two variables are distinct objects created with the same name)::

        broadcast_var = 3.14
        broadcast_var_1 = 3.14
        A[i] = B[i] + T.Broadcast(broadcast_var, 4) + T.Broadcast(broadcast_var_1, 4)

    Returns:
        The pass object produced by ``prim_func_pass`` with ``opt_level=0``.
    """

    def pass_fn(func: PrimFunc, mod, ctx):
        # A fresh mutator per function keeps its pending-definition queue isolated.
        return func.with_body(HoistBroadcastValuesMutator().visit_stmt(func.body))

    return prim_func_pass(pass_fn, opt_level=0)
Loading