Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,17 @@ TIR_DEFINE_TL_BUILTIN(pack_b16).set_num_inputs(2).set_attr<TCallEffectKind>(
TIR_DEFINE_TL_BUILTIN(sync_grid).set_num_inputs(0).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kOpaque));

// Warp-level barrier. Registered as variadic (-1 inputs) so callers may
// pass an optional lane mask, which codegen forwards to __syncwarp(mask).
TIR_DEFINE_TL_BUILTIN(sync_warp).set_num_inputs(-1).set_attr<TCallEffectKind>(
    "TCallEffectKind", Integer(CallEffectKind::kOpaque));

// Programmatic dependent launch: signal that dependent kernels may begin
// launching (lowered to cudaTriggerProgrammaticLaunchCompletion on CUDA).
TIR_DEFINE_TL_BUILTIN(pdl_trigger)
    .set_num_inputs(0)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

// Programmatic dependent launch: wait for the grids this kernel depends on
// (lowered to cudaGridDependencySynchronize on CUDA).
TIR_DEFINE_TL_BUILTIN(pdl_sync).set_num_inputs(0).set_attr<TCallEffectKind>(
    "TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(loop_break)
.set_num_inputs(0)
.set_attr<TCallEffectKind>("TCallEffectKind",
Expand Down
24 changes: 24 additions & 0 deletions src/op/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,30 @@ TVM_DLL const Op &wait_wgmma();
*/
TVM_DLL const Op &sync_grid();

/*!
 * \brief Synchronize all threads in a warp.
 *
 * sync_warp()      — full-warp barrier
 * sync_warp(mask)  — only the lanes named in `mask` participate
 *
 * Variadic: the optional lane mask is forwarded to the target's warp
 * barrier (e.g. __syncwarp(mask) on CUDA).
 */
TVM_DLL const Op &sync_warp();

/*!
 * \brief Programmatic dependency trigger.
 *
 * pdl_trigger()
 *
 * Signals that dependent kernels may begin launching (lowered to
 * cudaTriggerProgrammaticLaunchCompletion on CUDA).
 */
TVM_DLL const Op &pdl_trigger();

/*!
 * \brief Programmatic grid dependency synchronization.
 *
 * pdl_sync()
 *
 * Blocks until the grids this kernel depends on have completed
 * (lowered to cudaGridDependencySynchronize on CUDA).
 */
TVM_DLL const Op &pdl_sync();

/*!
* \brief tvm intrinsic for loop continue
*
Expand Down
13 changes: 13 additions & 0 deletions src/target/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1877,6 +1877,19 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
this->need_cooperative_groups_ = true;
this->PrintIndent();
this->stream << "cooperative_groups::this_grid().sync();\n";
} else if (op->op.same_as(tl::sync_warp())) {
this->PrintIndent();
this->stream << "__syncwarp(";
if (!op->args.empty()) {
this->stream << this->PrintExpr(op->args[0]);
}
this->stream << ");\n";
} else if (op->op.same_as(tl::pdl_trigger())) {
this->PrintIndent();
this->stream << "cudaTriggerProgrammaticLaunchCompletion();\n";
} else if (op->op.same_as(tl::pdl_sync())) {
this->PrintIndent();
this->stream << "cudaGridDependencySynchronize();\n";
} else if (op->op.same_as(tl::loop_break())) {
this->PrintIndent();
this->stream << "break;\n";
Expand Down
62 changes: 62 additions & 0 deletions testing/python/language/test_tilelang_language_warp_sync.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import tilelang
import tilelang.language as T
import torch
from tvm import tir
import tilelang.testing


@tilelang.jit
def kernel_with_warp_sync():
    """Build a kernel whose result depends on T.sync_warp() ordering.

    Thread 0 sleeps and then writes A[0]; thread 1 copies A[0] into B[0].
    The warp-level sync between the two makes thread 1 observe the write.
    """

    @T.prim_func
    def main(
        A: T.Tensor((1,), "int32"),
        B: T.Tensor((1,), "int32"),
    ):
        with T.Kernel(1, threads=32):
            tx = T.get_thread_binding()
            if tx == 0:
                # Delay the write so a missing warp sync would likely surface
                # as thread 1 reading a stale/uninitialized A[0].
                tir.call_extern("void", "__nanosleep", 100)
                A[0] = -1
            # Warp barrier: thread 1 must see A[0] = -1 after this point.
            T.sync_warp()
            if tx == 1:
                B[0] = A[0]

    return main


@tilelang.testing.requires_cuda
def test_warp_sync():
    """sync_warp lowers to __syncwarp and orders the cross-thread copy."""
    kernel = kernel_with_warp_sync()
    assert "__syncwarp" in kernel.get_kernel_source()
    src = torch.empty((1), device="cuda", dtype=torch.int32)
    dst = torch.empty((1), device="cuda", dtype=torch.int32)
    kernel(src, dst)
    # Thread 0 wrote -1 into A; thread 1 copied it into B after the sync.
    assert dst[0] == -1


@tilelang.jit
def kernel_with_shfl_sync():
    """Build a kernel that broadcasts lane 31's value across the warp.

    Each thread computes tx * 10; T.shfl_sync reads lane 31's value
    (31 * 10 = 310), so every element of A is expected to become 310.
    """

    @T.prim_func
    def main(
        A: T.Tensor((32,), "int32"),
    ):
        with T.Kernel(1, threads=32):
            tx = T.get_thread_binding()
            val = tx * 10
            # Full-warp mask (0xFFFFFFFF); source lane 31 holds 310.
            broadcast = T.shfl_sync(0xFFFFFFFF, val, 31)
            A[tx] = broadcast

    return main


@tilelang.testing.requires_cuda
def test_shfl_sync():
    """shfl_sync lowers to __shfl_sync and broadcasts lane 31's value."""
    kernel = kernel_with_shfl_sync()
    assert "__shfl_sync" in kernel.get_kernel_source()
    result = torch.empty((32), device="cuda", dtype=torch.int32)
    kernel(result)
    # Lane 31 holds 31 * 10 = 310; every lane should receive it.
    assert torch.all(result == 310)


# Allow running this test file directly (outside of pytest collection).
if __name__ == "__main__":
    tilelang.testing.main()
14 changes: 14 additions & 0 deletions tilelang/language/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,20 @@ def sync_threads(barrier_id: int = None, arrive_count: int = None):
return tir.call_intrin("int32", "tir.tvm_storage_sync", "shared", *args)


def sync_warp(mask: int | None = None):
    """Synchronize all threads in a warp.

    Emits the ``tl.sync_warp`` builtin, which codegen lowers to
    ``__syncwarp(mask)`` on CUDA.

    Parameters
    ----------
    mask : int, optional
        Bitmask of the lanes that participate in the synchronization.
        When omitted, no mask is passed and the target's default
        (full-warp) barrier is emitted.

    Returns
    -------
    PrimExpr
        The intrinsic call expression.
    """
    # Single call site: forward the mask only when provided.
    args = () if mask is None else (mask,)
    return tir.call_intrin("void", tir.op.Op.get("tl.sync_warp"), *args)


def shfl_sync(mask: int, value: int | PrimExpr, srcLane: int, width: int | None = None):
    """Receive data from a thread in the same warp (CUDA ``__shfl_sync``).

    Parameters
    ----------
    mask : int
        Bitmask of participating lanes (e.g. ``0xFFFFFFFF`` for the full warp).
    value : int or PrimExpr
        Value contributed by each lane; the result is lane ``srcLane``'s value.
    srcLane : int
        Lane whose value is returned to every participating lane.
    width : int, optional
        Logical warp width, forwarded to ``__shfl_sync`` when given.

    Returns
    -------
    PrimExpr
        Extern call expression with the same dtype as ``value``.
    """
    # A plain Python int has no .dtype attribute; fall back to int32 so the
    # advertised `int` value type works instead of raising AttributeError.
    dtype = value.dtype if hasattr(value, "dtype") else "int32"
    args = [mask, value, srcLane]
    if width is not None:
        args.append(width)
    return tir.call_extern(dtype, "__shfl_sync", *args)


def sync_global():
"""Synchronize all threads in the entire grid."""
tx, ty, tz = get_thread_bindings()
Expand Down
12 changes: 6 additions & 6 deletions tilelang/language/pdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@


def pdl_trigger():
    """Programmatic dependency trigger: emit the ``tl.pdl_trigger`` builtin."""
    trigger_op = tir.op.Op.get("tl.pdl_trigger")
    return tir.call_intrin("void", trigger_op)


def pdl_sync():
    """Programmatic grid dependency sync: emit the ``tl.pdl_sync`` builtin."""
    sync_op = tir.op.Op.get("tl.pdl_sync")
    return tir.call_intrin("void", sync_op)
Loading