diff --git a/src/op/builtin.cc b/src/op/builtin.cc
index c3a4d15cf..7f890ee1f 100644
--- a/src/op/builtin.cc
+++ b/src/op/builtin.cc
@@ -288,6 +288,17 @@ TIR_DEFINE_TL_BUILTIN(pack_b16).set_num_inputs(2).set_attr(
 TIR_DEFINE_TL_BUILTIN(sync_grid).set_num_inputs(0).set_attr(
     "TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
+TIR_DEFINE_TL_BUILTIN(sync_warp).set_num_inputs(-1).set_attr(
+    "TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_TL_BUILTIN(pdl_trigger)
+    .set_num_inputs(0)
+    .set_attr("TCallEffectKind",
+              Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_TL_BUILTIN(pdl_sync).set_num_inputs(0).set_attr(
+    "TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
 TIR_DEFINE_TL_BUILTIN(loop_break)
     .set_num_inputs(0)
     .set_attr("TCallEffectKind",
diff --git a/src/op/builtin.h b/src/op/builtin.h
index 16586d4f9..8480e9ac7 100644
--- a/src/op/builtin.h
+++ b/src/op/builtin.h
@@ -442,6 +442,30 @@ TVM_DLL const Op &wait_wgmma();
  */
 TVM_DLL const Op &sync_grid();
 
+/*!
+ * \brief Synchronize all threads in a warp
+ *
+ * sync_warp()
+ *
+ */
+TVM_DLL const Op &sync_warp();
+
+/*!
+ * \brief Programmatic dependency trigger.
+ *
+ * pdl_trigger()
+ *
+ */
+TVM_DLL const Op &pdl_trigger();
+
+/*!
+ * \brief Programmatic grid dependency synchronization.
+ *
+ * pdl_sync()
+ *
+ */
+TVM_DLL const Op &pdl_sync();
+
 /*!
  * \brief tvm intrinsic for loop continue
  *
diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc
index 5ab7ecc0f..0dfb85341 100644
--- a/src/target/codegen_cuda.cc
+++ b/src/target/codegen_cuda.cc
@@ -1877,6 +1877,19 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
     this->need_cooperative_groups_ = true;
     this->PrintIndent();
     this->stream << "cooperative_groups::this_grid().sync();\n";
+  } else if (op->op.same_as(tl::sync_warp())) {
+    this->PrintIndent();
+    this->stream << "__syncwarp(";
+    if (!op->args.empty()) {
+      this->stream << this->PrintExpr(op->args[0]);
+    }
+    this->stream << ");\n";
+  } else if (op->op.same_as(tl::pdl_trigger())) {
+    this->PrintIndent();
+    this->stream << "cudaTriggerProgrammaticLaunchCompletion();\n";
+  } else if (op->op.same_as(tl::pdl_sync())) {
+    this->PrintIndent();
+    this->stream << "cudaGridDependencySynchronize();\n";
   } else if (op->op.same_as(tl::loop_break())) {
     this->PrintIndent();
     this->stream << "break;\n";
diff --git a/testing/python/language/test_tilelang_language_warp_sync.py b/testing/python/language/test_tilelang_language_warp_sync.py
new file mode 100644
index 000000000..4c9aaff2a
--- /dev/null
+++ b/testing/python/language/test_tilelang_language_warp_sync.py
@@ -0,0 +1,62 @@
+import tilelang
+import tilelang.language as T
+import torch
+from tvm import tir
+import tilelang.testing
+
+
+@tilelang.jit
+def kernel_with_warp_sync():
+    @T.prim_func
+    def main(
+            A: T.Tensor((1,), "int32"),
+            B: T.Tensor((1,), "int32"),
+    ):
+        with T.Kernel(1, threads=32):
+            tx = T.get_thread_binding()
+            if tx == 0:
+                tir.call_extern("void", "__nanosleep", 100)
+                A[0] = -1
+            T.sync_warp()
+            if tx == 1:
+                B[0] = A[0]
+
+    return main
+
+
+@tilelang.testing.requires_cuda
+def test_warp_sync():
+    a = torch.empty((1), device="cuda", dtype=torch.int32)
+    b = torch.empty((1), device="cuda", dtype=torch.int32)
+    kernel = kernel_with_warp_sync()
+    assert "__syncwarp" in kernel.get_kernel_source()
+    kernel(a, b)
+    assert b[0] == -1
+
+
+@tilelang.jit
+def kernel_with_shfl_sync():
+    @T.prim_func
+    def main(
+            A: T.Tensor((32,), "int32"),
+    ):
+        with T.Kernel(1, threads=32):
+            tx = T.get_thread_binding()
+            val = tx * 10
+            broadcast = T.shfl_sync(0xFFFFFFFF, val, 31)
+            A[tx] = broadcast
+
+    return main
+
+
+@tilelang.testing.requires_cuda
+def test_shfl_sync():
+    a = torch.empty((32), device="cuda", dtype=torch.int32)
+    kernel = kernel_with_shfl_sync()
+    assert "__shfl_sync" in kernel.get_kernel_source()
+    kernel(a)
+    assert torch.all(a == 310)
+
+
+if __name__ == "__main__":
+    tilelang.testing.main()
diff --git a/tilelang/language/builtin.py b/tilelang/language/builtin.py
index 2932656ca..1e16f03c4 100644
--- a/tilelang/language/builtin.py
+++ b/tilelang/language/builtin.py
@@ -709,6 +709,20 @@ def sync_threads(barrier_id: int = None, arrive_count: int = None):
     return tir.call_intrin("int32", "tir.tvm_storage_sync", "shared", *args)
 
 
+def sync_warp(mask: int = None):
+    """Synchronize all threads in a warp."""
+    if mask is not None:
+        return tir.call_intrin("void", tir.op.Op.get("tl.sync_warp"), mask)
+    return tir.call_intrin("void", tir.op.Op.get("tl.sync_warp"))
+
+
+def shfl_sync(mask: int, value: int | PrimExpr, srcLane: int, width: int = None):
+    """Receives data from a thread in the same warp."""
+    if width is None:
+        return tir.call_extern(value.dtype, "__shfl_sync", mask, value, srcLane)
+    return tir.call_extern(value.dtype, "__shfl_sync", mask, value, srcLane, width)
+
+
 def sync_global():
     """Synchronize all threads in the entire grid."""
     tx, ty, tz = get_thread_bindings()
diff --git a/tilelang/language/pdl.py b/tilelang/language/pdl.py
index c1b3d7d07..ad8c66ce2 100644
--- a/tilelang/language/pdl.py
+++ b/tilelang/language/pdl.py
@@ -8,14 +8,14 @@
 
 
 def pdl_trigger():
-    return tir.call_extern(
-        "int32",  # cudaError_t
-        "cudaTriggerProgrammaticLaunchCompletion",
+    return tir.call_intrin(
+        "void",
+        tir.op.Op.get("tl.pdl_trigger"),
     )
 
 
 def pdl_sync():
-    return tir.call_extern(
-        "int32",  # cudaError_t
-        "cudaGridDependencySynchronize",
+    return tir.call_intrin(
+        "void",
+        tir.op.Op.get("tl.pdl_sync"),
     )
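
Usage sketch: the diff adds the PDL intrinsics but only `sync_warp`/`shfl_sync` are covered by the new test, so the example below is a minimal, hypothetical illustration of where `pdl_sync()` and `pdl_trigger()` might be placed in a TileLang kernel. It assumes the wrappers in `tilelang/language/pdl.py` are re-exported as `T.pdl_sync` / `T.pdl_trigger` (only `T.sync_warp` / `T.shfl_sync` are confirmed by the test), and that the host side launches the kernels with CUDA programmatic dependent launch enabled; the kernel name and shapes are illustrative.

```python
import tilelang
import tilelang.language as T


# Hypothetical consumer kernel in a PDL launch chain: it waits for the
# preceding grid to finish before reading A, then signals that its own
# dependent kernel may begin launching.
@tilelang.jit
def pdl_consumer_kernel(N: int = 128):

    @T.prim_func
    def main(A: T.Tensor((N,), "float32"), B: T.Tensor((N,), "float32")):
        with T.Kernel(1, threads=N):
            tx = T.get_thread_binding()
            # Emits cudaGridDependencySynchronize(): block until the
            # producer kernel in the chain has finished writing A.
            T.pdl_sync()
            B[tx] = A[tx]
            # Emits cudaTriggerProgrammaticLaunchCompletion(): allow the
            # dependent kernel to start.
            T.pdl_trigger()

    return main
```

Since the new codegen branch maps these ops directly to `cudaGridDependencySynchronize()` and `cudaTriggerProgrammaticLaunchCompletion()`, the generated source could be checked with `get_kernel_source()` in the same way the warp-sync tests do.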