diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index 0d6d98e25574..65012c6c0f0f 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -663,8 +663,7 @@ TVM_DLL const Op& ptx_cp_async();
  *                         Var global_ptr,
  *                         Expr global_offset,
  *                         size_t bytes,
- *                         Var barrier_ptr,
- *                         Expr barrier_offset);
+ *                         int barrier_id);
  */
 TVM_DLL const Op& ptx_cp_async_bulk();
@@ -681,7 +680,7 @@ TVM_DLL const Op& ptx_wait_group();
 /*!
  * \brief tvm intrinsics for ptx async copy barrier using cp.async.mbarrier.arrive
  *
- * ptx_cp_async_barrier(Var barrier_ptr, Expr barrier_offset)
+ * ptx_cp_async_barrier(int barrier_id)
  *
  */
 TVM_DLL const Op& ptx_cp_async_barrier();
@@ -689,7 +688,7 @@ TVM_DLL const Op& ptx_cp_async_barrier();
 /*!
  * \brief tvm intrinsics for ptx barrier initialization of thread count using mbarrier.init
  *
- * ptx_init_barrier_thread_count(Var barrier_ptr, Expr barrier_offset, int thread_count)
+ * ptx_init_barrier_thread_count(int barrier_id, int thread_count)
  *
  */
 TVM_DLL const Op& ptx_init_barrier_thread_count();
@@ -697,7 +696,7 @@ TVM_DLL const Op& ptx_init_barrier_thread_count();
 /*!
  * \brief tvm intrinsics for ptx barrier arrival using mbarrier.arrive
  *
- * ptx_arrive_barrier(Var barrier_ptr, Expr barrier_offset)
+ * ptx_arrive_barrier(int barrier_id)
  *
  */
 TVM_DLL const Op& ptx_arrive_barrier();
@@ -705,7 +704,7 @@ TVM_DLL const Op& ptx_arrive_barrier();
 /*!
  * \brief tvm intrinsic for ptx barrier arrival with expect tx using mbarrier.arrive.expect_tx
  *
- * ptx_arrive_barrier_expect_tx(Var barrier_ptr, Expr barrier_offset, int byte_count)
+ * ptx_arrive_barrier_expect_tx(int barrier_id, int byte_count)
  *
  */
 TVM_DLL const Op& ptx_arrive_barrier_expect_tx();
@@ -713,11 +712,19 @@ TVM_DLL const Op& ptx_arrive_barrier_expect_tx();
 /*!
  * \brief tvm intrinsics for ptx barrier wait using mbarrier.try_wait
  *
- * ptx_wait_barrier(Var barrier_ptr, Expr barrier_offset)
+ * ptx_wait_barrier(int barrier_id)
  *
  */
 TVM_DLL const Op& ptx_wait_barrier();
 
+/*!
+ * \brief tvm intrinsics to create N barriers
+ *
+ * create_barriers(int barrier_count)
+ *
+ */
+TVM_DLL const Op& create_barriers();
+
 /*!
  * \brief tvm intrinsic for storing the result of PTX MMA into a destination pointer.
  * For example, if each thread in a warp of size 32 has 4 elements from the result of
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index 337e06089583..5471288878f5 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -1849,6 +1849,7 @@ def wrapped(*args, **kwargs):
 ptx_arrive_barrier = _op_wrapper(_tir_op.ptx_arrive_barrier)
 ptx_arrive_barrier_expect_tx = _op_wrapper(_tir_op.ptx_arrive_barrier_expect_tx)
 ptx_wait_barrier = _op_wrapper(_tir_op.ptx_wait_barrier)
+create_barriers = _op_wrapper(_tir_op.create_barriers)
 assume = _op_wrapper(_tir_op.assume)
 undef = _op_wrapper(_tir_op.undef)
 TVMBackendAllocWorkspace = _op_wrapper(_tir_op.TVMBackendAllocWorkspace)
@@ -2125,6 +2126,7 @@ def wrapped(*args, **kwargs):
     "ptx_arrive_barrier",
     "ptx_arrive_barrier_expect_tx",
     "ptx_wait_barrier",
+    "create_barriers",
     "mma_store",
     "mma_fill",
     "vectorlow",
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 762fcb599f40..f0500290b888 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -71,6 +71,7 @@
     ptx_arrive_barrier,
     ptx_arrive_barrier_expect_tx,
     ptx_wait_barrier,
+    create_barriers,
 )
 from .op import vectorlow, vectorhigh, vectorcombine
 from .op import infinity, reinterpret
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index cb9227e8f2ea..30e2a2948769 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -1369,7 +1369,7 @@ def ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, by
 
 
 def ptx_cp_async_bulk(
-    dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes, barrier_ptr, barrier_offset
+    dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes, barrier_id
 ):
     """TVM intrinsic for ptx async copy from global to shared memory using cp.async.bulk
     https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk
@@ -1394,11 +1394,8 @@ def ptx_cp_async_bulk(
     bytes : int
         The data size to copy.
 
-    barrier_ptr : Var
-        The barrier shared memory pointer variable.
-
     barrier_id : int
-        The offset of the barrier shared memory pointer.
+        The ID of the barrier in shared memory.
 
     Returns
     -------
@@ -1413,8 +1410,7 @@ def ptx_cp_async_bulk(
         global_ptr,
         global_offset,
         bytes,
-        barrier_ptr,
-        barrier_offset,
+        barrier_id,
     )
 
 
@@ -1447,37 +1443,31 @@ def ptx_wait_group(num):
     return call_intrin("", "tir.ptx_wait_group", num)
 
 
-def ptx_cp_async_barrier(barrier_ptr, barrier_offset):
+def ptx_cp_async_barrier(barrier_id):
     """TVM intrinsic for ptx async copy barrier using cp.async.mbarrier.arrive
     https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-cp-async-mbarrier-arrive
 
     Parameters
     ----------
-    barrier_ptr : Var
-        The barrier shared memory pointer variable.
-
     barrier_id : int
-        The offset of the barrier shared memory pointer.
+        The ID of the barrier in shared memory.
 
     Returns
     -------
     call : PrimExpr
         The call expression.
""" - return call_intrin("", "tir.ptx_cp_async_barrier", barrier_ptr, barrier_offset) + return call_intrin("", "tir.ptx_cp_async_barrier", barrier_id) -def ptx_init_barrier_thread_count(barrier_ptr, barrier_offset, thread_count): +def ptx_init_barrier_thread_count(barrier_id, thread_count): """TVM intrinsic for ptx barrier initialization of thread count using mbarrier.init https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init Parameters ---------- - barrier_ptr : Var - The barrier shared memory pointer variable. - barrier_id : int - The offset of the barrier shared memory pointer. + The ID of the barrier shared memory pointer. thread_count : int Number of threads expected to arrive at the barrier. @@ -1487,43 +1477,35 @@ def ptx_init_barrier_thread_count(barrier_ptr, barrier_offset, thread_count): call : PrimExpr The call expression. """ - return call_intrin( - "", "tir.ptx_init_barrier_thread_count", barrier_ptr, barrier_offset, thread_count - ) + return call_intrin("", "tir.ptx_init_barrier_thread_count", barrier_id, thread_count) -def ptx_arrive_barrier(barrier_ptr, barrier_offset): +def ptx_arrive_barrier(barrier_id): """TVM intrinsic for ptx barrier arrival using mbarrier.arrive https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive Parameters ---------- - barrier_ptr : Var - The barrier shared memory pointer variable. - barrier_id : int - The offset of the barrier shared memory pointer. + The ID of the barrier shared memory pointer. Returns ------- call : PrimExpr The call expression. """ - return call_intrin("", "tir.ptx_arrive_barrier", barrier_ptr, barrier_offset) + return call_intrin("", "tir.ptx_arrive_barrier", barrier_id) -def ptx_arrive_barrier_expect_tx(barrier_ptr, barrier_offset, byte_count): +def ptx_arrive_barrier_expect_tx(barrier_id, byte_count): """TVM intrinsic for ptx barrier arrival with expect tx using mbarrier.arrive.expect_tx https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-expect-tx-operation Parameters ---------- - barrier_ptr : Var - The barrier shared memory pointer variable. - barrier_id : int - The offset of the barrier shared memory pointer. + The ID of the barrier shared memory pointer. byte_count : int Increases the tx count of the mbarrier object to track completion of @@ -1534,29 +1516,40 @@ def ptx_arrive_barrier_expect_tx(barrier_ptr, barrier_offset, byte_count): call : PrimExpr The call expression. """ - return call_intrin( - "", "tir.ptx_arrive_barrier_expect_tx", barrier_ptr, barrier_offset, byte_count - ) + return call_intrin("", "tir.ptx_arrive_barrier_expect_tx", barrier_id, byte_count) -def ptx_wait_barrier(barrier_ptr, barrier_offset): +def ptx_wait_barrier(barrier_id): """TVM intrinsic for ptx barrier wait using mbarrier.try_wait https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-mbarrier-try-wait Parameters ---------- - barrier_ptr : Var - The barrier shared memory pointer variable. - barrier_id : int - The offset of the barrier shared memory pointer. + The ID of the barrier shared memory pointer. + + Returns + ------- + call : PrimExpr + The call expression. 
+ """ + return call_intrin("", "tir.ptx_wait_barrier", barrier_id) + + +def create_barriers(barrier_count): + """TVM intrinsic to create N barriers + + Parameters + ---------- + barrier_count : int + The number of barriers to create. Returns ------- call : PrimExpr The call expression. """ - return call_intrin("", "tir.ptx_wait_barrier", barrier_ptr, barrier_offset) + return call_intrin("", "tir.create_barriers", barrier_count) def vectorlow(dtype, vec): diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index d880b978b5b9..7639ce606563 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -968,10 +968,10 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { std::string src = this->PrintExpr(op->args[2]); std::string src_offset = this->PrintExpr(op->args[3]); std::string size = this->PrintExpr(op->args[4]); - std::string barrier_ptr = this->PrintExpr(op->args[5]); - std::string barrier_offset = this->PrintExpr(op->args[6]); - this->stream << PrintCpAsyncBulkAsm(dst, dst_offset, src, src_offset, size, barrier_ptr, - barrier_offset); + int barrier_id = Downcast(op->args[5])->value; + CHECK(barrier_id < barrier_count_); + std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]"; + this->stream << PrintCpAsyncBulkAsm(dst, dst_offset, src, src_offset, size, barrier); } else if (op->op.same_as(builtin::ptx_commit_group())) { this->stream << "__asm__ __volatile__(\"cp.async.commit_group;\");\n\n"; } else if (op->op.same_as(builtin::ptx_wait_group())) { @@ -979,31 +979,50 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { this->stream << "__asm__ __volatile__(\"cp.async.wait_group " << n << ";\");\n\n"; } else if (op->op.same_as(builtin::ptx_cp_async_barrier())) { need_cast_smem_ptr_to_int_ = true; - std::string barrier_ptr = this->PrintExpr(op->args[0]); - std::string barrier_offset = this->PrintExpr(op->args[1]); - this->stream << PrintCpAsyncBarrierAsm(barrier_ptr, barrier_offset); + int barrier_id = Downcast(op->args[0])->value; + CHECK(barrier_id < barrier_count_); + std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]"; + this->stream << PrintCpAsyncBarrierAsm(barrier); } else if (op->op.same_as(builtin::ptx_init_barrier_thread_count())) { need_cast_smem_ptr_to_int_ = true; - std::string barrier_ptr = this->PrintExpr(op->args[0]); - std::string barrier_offset = this->PrintExpr(op->args[1]); - std::string thread_count = this->PrintExpr(op->args[2]); - this->stream << PrintInitBarrierThreadCountAsm(barrier_ptr, barrier_offset, thread_count); + int barrier_id = Downcast(op->args[0])->value; + CHECK(barrier_id < barrier_count_); + std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]"; + std::string thread_count = this->PrintExpr(op->args[1]); + this->stream << PrintInitBarrierThreadCountAsm(barrier, thread_count); } else if (op->op.same_as(builtin::ptx_arrive_barrier())) { need_cast_smem_ptr_to_int_ = true; - std::string barrier_ptr = this->PrintExpr(op->args[0]); - std::string barrier_offset = this->PrintExpr(op->args[1]); - this->stream << PrintArriveBarrierAsm(barrier_ptr, barrier_offset); + int barrier_id = Downcast(op->args[0])->value; + CHECK(barrier_id < barrier_count_); + std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]"; + this->stream << PrintArriveBarrierAsm(barrier); } else if (op->op.same_as(builtin::ptx_arrive_barrier_expect_tx())) { need_cast_smem_ptr_to_int_ = true; - std::string 
-    std::string barrier_ptr = this->PrintExpr(op->args[0]);
-    std::string barrier_offset = this->PrintExpr(op->args[1]);
-    std::string byte_count = this->PrintExpr(op->args[2]);
-    this->stream << PrintArriveBarrierExpectTxAsm(barrier_ptr, barrier_offset, byte_count);
+    int barrier_id = Downcast<IntImm>(op->args[0])->value;
+    CHECK(barrier_id < barrier_count_);
+    std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]";
+    std::string byte_count = this->PrintExpr(op->args[1]);
+    this->stream << PrintArriveBarrierExpectTxAsm(barrier, byte_count);
   } else if (op->op.same_as(builtin::ptx_wait_barrier())) {
     need_cast_smem_ptr_to_int_ = true;
-    std::string barrier_ptr = this->PrintExpr(op->args[0]);
-    std::string barrier_offset = this->PrintExpr(op->args[1]);
-    this->stream << PrintWaitBarrierAsm(barrier_ptr, barrier_offset);
+    int barrier_id = Downcast<IntImm>(op->args[0])->value;
+    CHECK(barrier_id < barrier_count_);
+    std::string barrier = barrier_name_ + "[" + std::to_string(barrier_id) + "]";
+    this->stream << PrintWaitBarrierAsm(barrier);
+  } else if (op->op.same_as(builtin::create_barriers())) {
+    CHECK_EQ(barrier_count_, -1);
+    int barrier_count = Downcast<IntImm>(op->args[0])->value;
+    // pad barrier alignment to avoid runtime alignment errors
+    CHECK_EQ(barrier_alignment_bytes_ % sizeof(uint64_t), 0);
+    int barrier_alignment_count = barrier_alignment_bytes_ / sizeof(uint64_t);
+    if (barrier_count % barrier_alignment_count != 0) {
+      barrier_count = ((barrier_count / barrier_alignment_count) + 1) * barrier_alignment_count;
+    }
+    barrier_count_ = barrier_count;
+    this->stream << "__shared__ __align__(" << barrier_alignment_bytes_ << ") uint64_t "
+                 << barrier_name_ << "[" << barrier_count << "];\n";
+    this->stream << "for (int i = 0; i < " << barrier_count << "; ++i) { " << barrier_name_
+                 << "[i] = 0; }\n";
   } else if (op->op.same_as(builtin::ptx_ldg32())) {
     /*
     asm volatile (
diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h
index 797ac9936375..bc7b34b500d8 100644
--- a/src/target/source/codegen_cuda.h
+++ b/src/target/source/codegen_cuda.h
@@ -109,6 +109,14 @@ class CodeGenCUDA final : public CodeGenC {
   // Op attribute map
   OpAttrMap<bool> op_need_warp_shuffle_ = Op::GetAttrMap<bool>("cuda.need_warp_shuffle");
 
+  // The name of the barrier array in shared memory
+  const std::string barrier_name_ = "barrier";
+  // The size of the barrier array in shared memory
+  int barrier_count_ = -1;
+  // The alignment of the barrier array in shared memory
+  // Set to 16 to maintain minimum alignment requirements for async bulk copy
+  const int barrier_alignment_bytes_ = 16;
+
   std::unordered_map<const VarNode*, std::string> fragment_shapes;
   std::unordered_map<const VarNode*, std::string> fragment_layouts;
   friend void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p);
diff --git a/src/target/source/ptx.cc b/src/target/source/ptx.cc
index dd7c7cb7c402..ed6125e74cae 100644
--- a/src/target/source/ptx.cc
+++ b/src/target/source/ptx.cc
@@ -713,8 +713,7 @@ std::string PrintCpAsyncBulkAsm(const std::string& shared_ptr,
                                 const std::string& shared_elem_offset,
                                 const std::string& global_ptr,
                                 const std::string& global_elem_offset, const std::string& bytes,
-                                const std::string& barrier_ptr,
-                                const std::string& barrier_elem_offset) {
+                                const std::string& barrier) {
   std::string asm_code = R"(
  {
    unsigned int smem_addr_int = cast_smem_ptr_to_int({smem_addr});
@@ -731,13 +730,12 @@ std::string PrintCpAsyncBulkAsm(const std::string& shared_ptr,
   replacer.register_rule("{smem_addr}", shared_ptr + " + " + shared_elem_offset);
   replacer.register_rule("{global_ptr}", global_ptr + " + " + global_elem_offset);
global_ptr + " + " + global_elem_offset); replacer.register_rule("{bytes}", bytes); - replacer.register_rule("{barrier}", barrier_ptr + " + " + barrier_elem_offset); + replacer.register_rule("{barrier}", "&" + barrier); asm_code = replacer.rewrite(asm_code); return asm_code; } -std::string PrintCpAsyncBarrierAsm(const std::string& barrier_ptr, - const std::string& barrier_elem_offset) { +std::string PrintCpAsyncBarrierAsm(const std::string& barrier) { std::string predicated_asm_code = R"( { unsigned int barrier_addr_int = cast_smem_ptr_to_int({barrier}); @@ -749,13 +747,12 @@ std::string PrintCpAsyncBarrierAsm(const std::string& barrier_ptr, )"; Replacer replacer; - replacer.register_rule("{barrier}", barrier_ptr + " + " + barrier_elem_offset); + replacer.register_rule("{barrier}", "&" + barrier); predicated_asm_code = replacer.rewrite(predicated_asm_code); return predicated_asm_code; } -std::string PrintInitBarrierThreadCountAsm(const std::string& barrier_ptr, - const std::string& barrier_elem_offset, +std::string PrintInitBarrierThreadCountAsm(const std::string& barrier, const std::string& thread_count) { std::string predicated_asm_code = R"( { @@ -769,14 +766,13 @@ std::string PrintInitBarrierThreadCountAsm(const std::string& barrier_ptr, )"; Replacer replacer; - replacer.register_rule("{barrier}", barrier_ptr + " + " + barrier_elem_offset); + replacer.register_rule("{barrier}", "&" + barrier); replacer.register_rule("{thread_count}", thread_count); predicated_asm_code = replacer.rewrite(predicated_asm_code); return predicated_asm_code; } -std::string PrintArriveBarrierAsm(const std::string& barrier_ptr, - const std::string& barrier_elem_offset) { +std::string PrintArriveBarrierAsm(const std::string& barrier) { std::string predicated_asm_code = R"( { unsigned int barrier_addr_int = cast_smem_ptr_to_int({barrier}); @@ -788,13 +784,12 @@ std::string PrintArriveBarrierAsm(const std::string& barrier_ptr, )"; Replacer replacer; - replacer.register_rule("{barrier}", barrier_ptr + " + " + barrier_elem_offset); + replacer.register_rule("{barrier}", "&" + barrier); predicated_asm_code = replacer.rewrite(predicated_asm_code); return predicated_asm_code; } -std::string PrintArriveBarrierExpectTxAsm(const std::string& barrier_ptr, - const std::string& barrier_elem_offset, +std::string PrintArriveBarrierExpectTxAsm(const std::string& barrier, const std::string& byte_count) { std::string predicated_asm_code = R"( { @@ -808,14 +803,13 @@ std::string PrintArriveBarrierExpectTxAsm(const std::string& barrier_ptr, )"; Replacer replacer; - replacer.register_rule("{barrier}", barrier_ptr + " + " + barrier_elem_offset); + replacer.register_rule("{barrier}", "&" + barrier); replacer.register_rule("{byte_count}", byte_count); predicated_asm_code = replacer.rewrite(predicated_asm_code); return predicated_asm_code; } -std::string PrintWaitBarrierAsm(const std::string& barrier_ptr, - const std::string& barrier_elem_offset) { +std::string PrintWaitBarrierAsm(const std::string& barrier) { std::string predicated_asm_code = R"( { unsigned int barrier_addr_int = cast_smem_ptr_to_int({barrier}); @@ -828,7 +822,7 @@ std::string PrintWaitBarrierAsm(const std::string& barrier_ptr, )"; Replacer replacer; - replacer.register_rule("{barrier}", barrier_ptr + " + " + barrier_elem_offset); + replacer.register_rule("{barrier}", "&" + barrier); predicated_asm_code = replacer.rewrite(predicated_asm_code); return predicated_asm_code; } diff --git a/src/target/source/ptx.h b/src/target/source/ptx.h index a73180d40b77..13d2f3cefc29 
--- a/src/target/source/ptx.h
+++ b/src/target/source/ptx.h
@@ -115,60 +115,48 @@ std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr,
  * \param global_ptr: The pointer to the global memory.
  * \param global_elem_offset: The offset into the global memory.
  * \param bytes: The number of bytes to copy.
- * \param barrier_ptr: The pointer to the barrier in shared memory.
- * \param barrier_elem_offset: The offset to the barrier in shared memory.
+ * \param barrier: The name of the barrier in shared memory.
  */
 std::string PrintCpAsyncBulkAsm(const std::string& shared_ptr,
                                 const std::string& shared_elem_offset,
                                 const std::string& global_ptr,
                                 const std::string& global_elem_offset, const std::string& bytes,
-                                const std::string& barrier_ptr,
-                                const std::string& barrier_elem_offset);
+                                const std::string& barrier);
 
 /*!
  * \brief Print ptx async copy barrier using cp.async.mbarrier.arrive
- * \param barrier_ptr: The pointer to the barrier in shared memory.
- * \param barrier_elem_offset: The offset to the barrier in shared memory.
+ * \param barrier: The name of the barrier in shared memory.
  */
-std::string PrintCpAsyncBarrierAsm(const std::string& barrier_ptr,
-                                   const std::string& barrier_elem_offset);
+std::string PrintCpAsyncBarrierAsm(const std::string& barrier);
 
 /*!
  * \brief Print ptx barrier initialization of thread count using mbarrier.init
- * \param barrier_ptr: The pointer to the barrier in shared memory.
- * \param barrier_elem_offset: The offset to the barrier in shared memory.
+ * \param barrier: The name of the barrier in shared memory.
  * \param thread_count: The number of threads expected to arrive at the barrier.
  */
-std::string PrintInitBarrierThreadCountAsm(const std::string& barrier_ptr,
-                                           const std::string& barrier_elem_offset,
+std::string PrintInitBarrierThreadCountAsm(const std::string& barrier,
                                            const std::string& thread_count);
 
 /*!
  * \brief Print ptx barrier arrival using mbarrier.arrive
- * \param barrier_ptr: The pointer to the barrier in shared memory.
- * \param barrier_elem_offset: The offset to the barrier in shared memory.
+ * \param barrier: The name of the barrier in shared memory.
  */
-std::string PrintArriveBarrierAsm(const std::string& barrier_ptr,
-                                  const std::string& barrier_elem_offset);
+std::string PrintArriveBarrierAsm(const std::string& barrier);
 
 /*!
  * \brief Print ptx barrier arrival with expect tx operation using mbarrier.arrive.expect_tx
- * \param barrier_ptr: The pointer to the barrier in shared memory.
- * \param barrier_elem_offset: The offset to the barrier in shared memory.
+ * \param barrier: The name of the barrier in shared memory.
  * \param byte_count: Increases the tx count of the mbarrier object to track completion of
  * additional async transactions.
 */
-std::string PrintArriveBarrierExpectTxAsm(const std::string& barrier_ptr,
-                                          const std::string& barrier_elem_offset,
+std::string PrintArriveBarrierExpectTxAsm(const std::string& barrier,
                                           const std::string& byte_count);
 
 /*!
  * \brief Print ptx barrier wait using mbarrier.try_wait
- * \param barrier_ptr: The pointer to the barrier in shared memory.
- * \param barrier_elem_offset: The offset to the barrier in shared memory.
+ * \param barrier: The name of the barrier in shared memory.
  */
-std::string PrintWaitBarrierAsm(const std::string& barrier_ptr,
-                                const std::string& barrier_elem_offset);
+std::string PrintWaitBarrierAsm(const std::string& barrier);
 
 }  // namespace codegen
 }  // namespace tvm
diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc
index a4116abf136f..1b80959b5705 100644
--- a/src/tir/op/builtin.cc
+++ b/src/tir/op/builtin.cc
@@ -297,15 +297,22 @@ TIR_DEFINE_BUILTIN_FUNC(ptx_wait_group)
 TIR_DEFINE_BUILTIN_FUNC(ptx_cp_async_barrier)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
 TIR_DEFINE_BUILTIN_FUNC(ptx_init_barrier_thread_count)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
 TIR_DEFINE_BUILTIN_FUNC(ptx_arrive_barrier)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
 TIR_DEFINE_BUILTIN_FUNC(ptx_arrive_barrier_expect_tx)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
 TIR_DEFINE_BUILTIN_FUNC(ptx_wait_barrier)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
+TIR_DEFINE_BUILTIN_FUNC(create_barriers)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
 TIR_DEFINE_BUILTIN_FUNC(mma_store)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque))
     .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
diff --git a/tests/python/unittest/test_tir_op_types.py b/tests/python/unittest/test_tir_op_types.py
index e4922e1e0c76..7398ee781b9e 100644
--- a/tests/python/unittest/test_tir_op_types.py
+++ b/tests/python/unittest/test_tir_op_types.py
@@ -237,10 +237,7 @@ def test_op_ptx_cp_async():
 def test_op_ptx_cp_async_bulk():
     buffer_shared = tir.decl_buffer([16, 16], "float16", scope="shared")
     buffer_local = tir.decl_buffer([8], "float16", scope="local")
-    barrier = tir.decl_buffer([1], "uint64", scope="shared")
-    expr = tir.ptx_cp_async_bulk(
-        "float16", buffer_shared.data, 0, buffer_local.data, 0, 16, barrier.data, 0
-    )
+    expr = tir.ptx_cp_async_bulk("float16", buffer_shared.data, 0, buffer_local.data, 0, 16, 0)
     assert expr.op.name == "tir.ptx_cp_async_bulk"
 
 
@@ -255,30 +252,35 @@ def test_op_ptx_wait_group():
 
 
 def test_op_ptx_cp_async_barrier():
-    expr = tir.ptx_cp_async_barrier("barrier", 0)
+    expr = tir.ptx_cp_async_barrier(0)
     assert expr.op.name == "tir.ptx_cp_async_barrier"
 
 
 def test_op_ptx_init_barrier_thread_count():
-    expr = tir.ptx_init_barrier_thread_count("barrier", 0, 32)
+    expr = tir.ptx_init_barrier_thread_count(0, 32)
     assert expr.op.name == "tir.ptx_init_barrier_thread_count"
 
 
 def test_op_ptx_arrive_barrier():
-    expr = tir.ptx_arrive_barrier("barrier", 0)
+    expr = tir.ptx_arrive_barrier(0)
    assert expr.op.name == "tir.ptx_arrive_barrier"
 
 
 def test_op_ptx_arrive_barrier_expect_tx():
-    expr = tir.ptx_arrive_barrier_expect_tx("barrier", 0, 32)
+    expr = tir.ptx_arrive_barrier_expect_tx(0, 32)
     assert expr.op.name == "tir.ptx_arrive_barrier_expect_tx"
 
 
 def test_op_ptx_wait_barrier():
-    expr = tir.ptx_wait_barrier("barrier", 0)
+    expr = tir.ptx_wait_barrier(0)
     assert expr.op.name == "tir.ptx_wait_barrier"
 
 
+def test_op_create_barriers():
+    expr = tir.create_barriers(16)
+    assert expr.op.name == "tir.create_barriers"
+
+
 def test_tir_op_vectorlow():
     buffer = tir.decl_buffer((4, 4), "int8", offset_factor=1)
     vec = buffer.vload([0, 0], dtype="int8x16")
diff --git a/tests/python/unittest/test_tir_ptx_cp_async.py b/tests/python/unittest/test_tir_ptx_cp_async.py
index e6d3942ce500..d7600238542d 100644
--- a/tests/python/unittest/test_tir_ptx_cp_async.py
+++ b/tests/python/unittest/test_tir_ptx_cp_async.py
@@ -71,23 +71,13 @@ def ptx_cp_async_barrier(
     T.launch_thread(bx, 1)
     T.launch_thread(tx, 32)
     with T.block():
-        # Shared memory targets for cp.async.bulk must be 16 byte aligned
-        # Problem: CUDA codegen does not support allocation alignment
-        # Workaround: Ensure that `A_shared` occurs before `barrier` in program order
-        # by allocating and initializing `A_shared` before `barrier`
-        # which should result in `A_shared` being 16+ byte aligned
-        # given it will be the first shared memory allocation
-        # TODO(Straw) Add CUDA codegen support for allocation alignment
         A_shared = T.alloc_buffer([32, 128], "float16", scope="shared")
-        A_shared[0, 0] = 0
-
-        barrier = T.alloc_buffer([1], "uint64", scope="shared")
-        barrier[0] = 0
 
         T.reads(A[0:32, 0:128])
         T.writes(B[0:32, 0:128])
 
-        T.evaluate(T.ptx_init_barrier_thread_count(barrier.data, 0, 32, dtype=""))
+        T.evaluate(T.create_barriers(1, dtype=""))
+        T.evaluate(T.ptx_init_barrier_thread_count(0, 32, dtype=""))
 
         for i in range(16):
             T.evaluate(
@@ -96,9 +86,9 @@ def ptx_cp_async_barrier(
                 )
             )
 
-        T.evaluate(T.ptx_cp_async_barrier(barrier.data, 0, dtype=""))
-        T.evaluate(T.ptx_arrive_barrier(barrier.data, 0, dtype=""))
-        T.evaluate(T.ptx_wait_barrier(barrier.data, 0, dtype=""))
+        T.evaluate(T.ptx_cp_async_barrier(0, dtype=""))
+        T.evaluate(T.ptx_arrive_barrier(0, dtype=""))
+        T.evaluate(T.ptx_wait_barrier(0, dtype=""))
 
         for i in range(128):
             B[tx, i] = A_shared[tx, i]
@@ -126,32 +116,20 @@ def ptx_cp_async_bulk(A: T.Buffer((32, 128), "float16"), B: T.Buffer((32, 128),
     T.launch_thread(bx, 1)
     T.launch_thread(tx, 32)
     with T.block():
-        # Shared memory targets for cp.async.bulk must be 16 byte aligned
-        # Problem: CUDA codegen does not support allocation alignment
-        # Workaround: Ensure that `A_shared` occurs before `barrier` in program order
-        # by allocating and initializing `A_shared` before `barrier`
-        # which should result in `A_shared` being 16+ byte aligned
-        # given it will be the first shared memory allocation
-        # TODO(Straw) Add CUDA codegen support for allocation alignment
-        A_shared = T.alloc_buffer([32, 128], "float16", scope="shared", align=16)
-        A_shared[0, 0] = 0
-
-        barrier = T.alloc_buffer([1], "uint64", scope="shared")
-        barrier[0] = 0
+        A_shared = T.alloc_buffer([32, 128], "float16", scope="shared")
 
         T.reads(A[0:32, 0:128])
         T.writes(B[0:32, 0:128])
 
-        T.evaluate(T.ptx_init_barrier_thread_count(barrier.data, 0, 32, dtype=""))
+        T.evaluate(T.create_barriers(1, dtype=""))
+        T.evaluate(T.ptx_init_barrier_thread_count(0, 32, dtype=""))
 
         T.evaluate(
-            T.ptx_cp_async_bulk(
-                A_shared.data, tx * 128, A.data, tx * 128, 256, barrier.data, 0, dtype="float16"
-            )
+            T.ptx_cp_async_bulk(A_shared.data, tx * 128, A.data, tx * 128, 256, 0, dtype="float16")
         )
 
-        T.evaluate(T.ptx_arrive_barrier_expect_tx(barrier.data, 0, 256, dtype=""))
-        T.evaluate(T.ptx_wait_barrier(barrier.data, 0, dtype=""))
+        T.evaluate(T.ptx_arrive_barrier_expect_tx(0, 256, dtype=""))
+        T.evaluate(T.ptx_wait_barrier(0, dtype=""))
 
         for i in range(128):
             B[tx, i] = A_shared[tx, i]
diff --git a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
index ff70eeae81ab..61f0892a9cf3 100644
--- a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
+++ b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
@@ -193,21 +193,21 @@ def ptx_global_to_shared_copy_fp32x1_barrier(
     T.launch_thread(bx, 1)
     T.launch_thread(tx, 32)
     with T.block():
-        barrier = T.alloc_buffer([1], "uint64", scope="shared")
         A_shared = T.alloc_buffer([32, 128], "float32", scope="shared")
+
         T.reads(A[0:32, 0:128])
-        T.writes(B[0:32, 0:128], barrier[0:1])
+        T.writes(B[0:32, 0:128])
 
-        barrier[0] = 0
-        T.evaluate(T.ptx_init_barrier_thread_count(barrier.data, 0, 32, dtype=""))
+        T.evaluate(T.create_barriers(1, dtype=""))
+        T.evaluate(T.ptx_init_barrier_thread_count(0, 32, dtype=""))
 
         T.attr("default", "async_scope", 1)
         for i in T.serial(128):
             A_shared[tx, i] = A[tx, i]
 
-        T.evaluate(T.ptx_cp_async_barrier(barrier.data, 0, dtype=""))
-        T.evaluate(T.ptx_arrive_barrier(barrier.data, 0, dtype=""))
-        T.evaluate(T.ptx_wait_barrier(barrier.data, 0, dtype=""))
+        T.evaluate(T.ptx_cp_async_barrier(0, dtype=""))
+        T.evaluate(T.ptx_arrive_barrier(0, dtype=""))
+        T.evaluate(T.ptx_wait_barrier(0, dtype=""))
 
         for i in range(128):
             B[tx, i] = A_shared[tx, i]
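
For reference, the barrier workflow after this change looks roughly as follows. This is a minimal sketch distilled from the updated test_tir_ptx_cp_async.py above; the kernel name `bulk_copy` and the shapes are illustrative. Note that `create_barriers` must appear once, before any intrinsic that references a barrier ID: the CUDA codegen bounds-checks `barrier_id < barrier_count_` at every use and rejects a second `create_barriers` via `CHECK_EQ(barrier_count_, -1)`.

from tvm.script import tir as T


@T.prim_func
def bulk_copy(A: T.Buffer((32, 128), "float16"), B: T.Buffer((32, 128), "float16")) -> None:
    bx = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(bx, 1)
    T.launch_thread(tx, 32)
    with T.block():
        A_shared = T.alloc_buffer([32, 128], "float16", scope="shared")
        T.reads(A[0:32, 0:128])
        T.writes(B[0:32, 0:128])
        # Emit the shared uint64 barrier array; codegen declares it with
        # __align__(16) and pads the element count so its size stays a
        # multiple of 16 bytes (the alignment cp.async.bulk requires).
        T.evaluate(T.create_barriers(1, dtype=""))
        # Barrier 0 expects all 32 threads of the block to arrive.
        T.evaluate(T.ptx_init_barrier_thread_count(0, 32, dtype=""))
        # Each thread bulk-copies one 256-byte row, tracked by barrier 0
        # (the trailing 0 is the barrier ID, replacing the old ptr/offset pair).
        T.evaluate(
            T.ptx_cp_async_bulk(A_shared.data, tx * 128, A.data, tx * 128, 256, 0, dtype="float16")
        )
        # Arrive with the expected transaction byte count, then block until
        # the async copies tracked by barrier 0 complete.
        T.evaluate(T.ptx_arrive_barrier_expect_tx(0, 256, dtype=""))
        T.evaluate(T.ptx_wait_barrier(0, dtype=""))
        for i in range(128):
            B[tx, i] = A_shared[tx, i]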