diff --git a/examples/flash_attention/example_gqa_bwd_tma_reduce.py b/examples/flash_attention/example_gqa_bwd_tma_reduce.py index fea547b6e..4a5290c28 100644 --- a/examples/flash_attention/example_gqa_bwd_tma_reduce.py +++ b/examples/flash_attention/example_gqa_bwd_tma_reduce.py @@ -5,8 +5,6 @@ from tilelang.contrib import nvcc import argparse -tilelang.disable_cache() - @tilelang.jit( out_idx=[3, 4], diff --git a/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py b/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py index a9f45e077..1bc8fd1eb 100644 --- a/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py +++ b/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py @@ -508,8 +508,8 @@ def forward( total_q = q_unpad.shape[0] total_kv = k_unpad.shape[0] - mod = flashattn_fwd(BATCH, total_q, total_kv, N_CTX, H, max_seqlen_q, D_HEAD_QK, D_HEAD_V, causal, block_M, block_N, groups) - o_unpad, lse = mod(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k) + kernel = flashattn_fwd(BATCH, total_q, total_kv, N_CTX, H, max_seqlen_q, D_HEAD_QK, D_HEAD_V, causal, block_M, block_N, groups) + o_unpad, lse = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k) o = pad_input(o_unpad, indices_q, BATCH, N_CTX) ctx.save_for_backward(q_unpad, k_unpad, v_unpad, o_unpad, lse, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k) ctx.batch = BATCH diff --git a/examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py b/examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py index 1672dbfb8..fbee75807 100644 --- a/examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py +++ b/examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py @@ -1,8 +1,6 @@ import tilelang import tilelang.language as T -tilelang.disable_cache() - # add decorator @tilelang.jit if you want to return a torch function # @tilelang.jit @@ -47,7 +45,6 @@ def main( def main(M=16384, N=16384, K=16384): 
- tilelang.disable_cache() block_M = 128 block_N = 128 block_K = 64 @@ -83,7 +80,6 @@ def main(M=16384, N=16384, K=16384): def run_regression_perf(M=16384, N=16384, K=16384): - tilelang.disable_cache() block_M = 128 block_N = 128 block_K = 64 diff --git a/src/op/builtin.cc b/src/op/builtin.cc index a0ee8acd8..7f65d9300 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -203,6 +203,11 @@ TIR_DEFINE_TL_BUILTIN(ptx_cp_async_barrier_noinc) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_TL_BUILTIN(ptx_cp_async) + .set_num_inputs(-1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_TL_BUILTIN(fence_proxy_async) .set_num_inputs(0) .set_attr("TCallEffectKind", diff --git a/src/op/builtin.h b/src/op/builtin.h index 0e39e9ad4..e12a3789c 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -304,6 +304,15 @@ TVM_DLL const Op &ptx_stmatrix(); */ TVM_DLL const Op &ptx_cp_async_barrier_noinc(); +/*! + * \brief TileLang intrinsic for PTX async copy from global to shared memory + * + * ptx_cp_async(dst_access_ptr, src_access_ptr, bytes) + * ptx_cp_async(dst_access_ptr, src_access_ptr, bytes, predicate) + * + */ +TVM_DLL const Op &ptx_cp_async(); + /*! 
* \brief Pack two b16 value into a b32 value * diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 5a4243471..4a7b6bf46 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -1452,23 +1452,48 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { return ss.str(); }; if (op->op.same_as(builtin::ptx_cp_async())) { + // args[0] = dst_access_ptr, args[1] = src_access_ptr, args[2] = bytes, + // args[3] = predicate (optional) + ICHECK(op->args.size() == 3 || op->args.size() == 4) + << "ptx_cp_async expects 3 or 4 arguments (dst_access_ptr, " + "src_access_ptr, bytes, [predicate])"; + std::string dst = this->PrintExpr(op->args[0]); - std::string dst_offset = this->PrintExpr(op->args[1]); - std::string src = this->PrintExpr(op->args[2]); - std::string src_offset = this->PrintExpr(op->args[3]); - std::string size = this->PrintExpr(op->args[4]); - // use size of argument list to indicate whether or not to use predicated - // cp.async - if (op->args.size() == 5) { - this->PrintIndent(); - this->stream << "tl::cp_async_gs<" << size << ">(" << dst << "+" - << dst_offset << ", " << src << "+" << src_offset << ");\n"; + std::string src = this->PrintExpr(op->args[1]); + std::string size = this->PrintExpr(op->args[2]); + + this->PrintIndent(); + if (op->args.size() == 3) { + // Non-predicated version + this->stream << "tl::cp_async_gs<" << size << ">(" << dst << ", " << src + << ");\n"; } else { - std::string condition = this->PrintExpr(op->args[5]); - this->PrintIndent(); + // Predicated version + std::string condition = this->PrintExpr(op->args[3]); + this->stream << "tl::cp_async_gs_conditional<" << size << ">(" << dst + << ", " << src << ", " << condition << ");\n"; + } + } else if (op->op.same_as(tl::ptx_cp_async())) { + // TileLang version: args[0] = dst_access_ptr, args[1] = src_access_ptr, + // args[2] = bytes, args[3] = predicate (optional) + ICHECK(op->args.size() == 3 || op->args.size() == 4) + << 
"tl::ptx_cp_async expects 3 or 4 arguments (dst_access_ptr, " + "src_access_ptr, bytes, [predicate])"; + + std::string dst = this->PrintExpr(op->args[0]); + std::string src = this->PrintExpr(op->args[1]); + std::string size = this->PrintExpr(op->args[2]); + + this->PrintIndent(); + if (op->args.size() == 3) { + // Non-predicated version + this->stream << "tl::cp_async_gs<" << size << ">(" << dst << ", " << src + << ");\n"; + } else { + // Predicated version + std::string condition = this->PrintExpr(op->args[3]); this->stream << "tl::cp_async_gs_conditional<" << size << ">(" << dst - << "+" << dst_offset << ", " << src << "+" << src_offset - << ", " << condition << ");\n"; + << ", " << src << ", " << condition << ");\n"; } } else if (op->op.same_as(builtin::ptx_commit_group())) { print_extern_call_stmt("tl::cp_async_commit"); @@ -2276,22 +2301,6 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { os << "for (int i = 0; i < " << num_elem << "; ++i) {\n"; os << dst << "[" << dst_offset << " + i] = 0.0;"; os << "}\n"; - } else if (op->op.same_as(builtin::ptx_cp_async())) { - std::string dst = this->PrintExpr(op->args[0]); - std::string dst_offset = this->PrintExpr(op->args[1]); - std::string src = this->PrintExpr(op->args[2]); - std::string src_offset = this->PrintExpr(op->args[3]); - std::string size = this->PrintExpr(op->args[4]); - need_cast_smem_ptr_to_int_ = true; - // use size of argument list to indicate whether or not to use predicated - // cp.async - if (op->args.size() == 5) { - this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset, - size); - } else { - this->stream << PrintPredicatedCpAsyncAssembly( - dst, dst_offset, src, src_offset, size, this->PrintExpr(op->args[5])); - } } else if (op->op.same_as(builtin::ptx_cp_async_bulk())) { need_cast_smem_ptr_to_int_ = true; std::string dst = this->PrintExpr(op->args[0]); diff --git a/src/target/ptx.cc b/src/target/ptx.cc index 53f83ded9..83bb1096c 100644 --- 
a/src/target/ptx.cc +++ b/src/target/ptx.cc @@ -1341,9 +1341,13 @@ std::string PrintCpAsyncAssembly(const std::string &shared_ptr, )"; Replacer replacer; replacer.register_rule("{smem_addr}", - shared_ptr + " + " + shared_elem_offset); + shared_elem_offset.empty() + ? shared_ptr + : shared_ptr + " + " + shared_elem_offset); replacer.register_rule("{global_ptr}", - global_ptr + " + " + global_elem_offset); + global_elem_offset.empty() + ? global_ptr + : global_ptr + " + " + global_elem_offset); replacer.register_rule("{bytes}", bytes); replacer.register_rule("{cg_or_ca}", bytes == "16" ? "cg" : "ca"); asm_code = replacer.rewrite(asm_code); @@ -1396,9 +1400,13 @@ std::string PrintPredicatedCpAsyncAssembly( Replacer replacer; replacer.register_rule("{smem_addr}", - shared_ptr + " + " + shared_elem_offset); + shared_elem_offset.empty() + ? shared_ptr + : shared_ptr + " + " + shared_elem_offset); replacer.register_rule("{global_ptr}", - global_ptr + " + " + global_elem_offset); + global_elem_offset.empty() + ? global_ptr + : global_ptr + " + " + global_elem_offset); replacer.register_rule("{bytes}", bytes); replacer.register_rule("{cg_or_ca}", bytes == "16" ? "cg" : "ca"); replacer.register_rule("{store_shared}", store_shared); diff --git a/src/target/ptx.h b/src/target/ptx.h index 566cded6f..85d9b947b 100644 --- a/src/target/ptx.h +++ b/src/target/ptx.h @@ -193,9 +193,11 @@ std::string PrintLoadMatrixAssembly(bool trans, int num, /*! * \brief Print ptx cp.async assembly string given parameters. * \param shared_ptr: The pointer to the destination shared memory. - * \param shared_elem_offset: The offset into the shared memory. + * \param shared_elem_offset: The offset into the shared memory (empty for no + * offset). * \param global_ptr: The pointer to the global memory. - * \param global_elem_offset: The offset into the global memory. + * \param global_elem_offset: The offset into the global memory (empty for no + * offset). 
* \param bytes: The number of bytes to copy, valid values are 4, 8, and 16. */ std::string PrintCpAsyncAssembly(const std::string &shared_ptr, @@ -204,12 +206,27 @@ std::string PrintCpAsyncAssembly(const std::string &shared_ptr, const std::string &global_elem_offset, const std::string &bytes); +/*! + * \brief Print ptx cp.async assembly string given parameters (no offset + * version). + * \param shared_ptr: The pointer to the destination shared memory. + * \param global_ptr: The pointer to the global memory. + * \param bytes: The number of bytes to copy, valid values are 4, 8, and 16. + */ +inline std::string PrintCpAsyncAssembly(const std::string &shared_ptr, + const std::string &global_ptr, + const std::string &bytes) { + return PrintCpAsyncAssembly(shared_ptr, "", global_ptr, "", bytes); +} + /*! * \brief Print predicated ptx cp.async assembly string given parameters. * \param shared_ptr: The pointer to the destination shared memory. - * \param shared_elem_offset: The offset into the shared memory. + * \param shared_elem_offset: The offset into the shared memory (empty for no + * offset). * \param global_ptr: The pointer to the global memory. - * \param global_elem_offset: The offset into the global memory. + * \param global_elem_offset: The offset into the global memory (empty for no + * offset). * \param bytes: The number of bytes to copy, valid values are 4, 8, and 16. * \param predicate_value: The value of predicate `@p`. */ @@ -218,6 +235,21 @@ std::string PrintPredicatedCpAsyncAssembly( const std::string &global_ptr, const std::string &global_elem_offset, const std::string &bytes, const std::string &predicate_value); +/*! + * \brief Print predicated ptx cp.async assembly string given parameters (no + * offset version). + * \param shared_ptr: The pointer to the destination shared memory. + * \param global_ptr: The pointer to the global memory. + * \param bytes: The number of bytes to copy, valid values are 4, 8, and 16. 
+ * \param predicate_value: The value of predicate `@p`. + */ +inline std::string PrintPredicatedCpAsyncAssembly( + const std::string &shared_ptr, const std::string &global_ptr, + const std::string &bytes, const std::string &predicate_value) { + return PrintPredicatedCpAsyncAssembly(shared_ptr, "", global_ptr, "", bytes, + predicate_value); +} + /*! * \brief Print ptx async copy from global to shared memory using cp.async.bulk * \param shared_ptr: The pointer to the destination shared memory. diff --git a/src/transform/inject_ptx_async_copy.cc b/src/transform/inject_ptx_async_copy.cc index 1fadefbf4..a62bac762 100644 --- a/src/transform/inject_ptx_async_copy.cc +++ b/src/transform/inject_ptx_async_copy.cc @@ -95,16 +95,31 @@ class PTXAsyncCopyInjector : public StmtMutator { if (indices_lanes == 1) { auto src_offset = load->indices[0]; auto dst_offset = store->indices[0]; - Array args = { - store->buffer->data, mul(dst_offset, PrimExpr(index_factor)), - load->buffer->data, src_offset, PrimExpr(bytes)}; - // use arguments size to indicate whether or not to use predicated - // cp.async + + // Calculate the number of elements based on bytes and dtype + int dst_elem_count = bytes / dst_elem_type->bytes(); + int src_elem_count = bytes / src_elem_type->bytes(); + + // Create access_ptr for destination (shared memory, write access) + auto dst_access_ptr = store->buffer.access_ptr( + 2, DataType::Handle(), 1, dst_offset, PrimExpr(dst_elem_count)); + + // Create access_ptr for source (global memory, read access) + auto src_access_ptr = load->buffer.access_ptr( + 1, DataType::Handle(), 1, src_offset, PrimExpr(src_elem_count)); + + ffi::Array cp_async_args; if (predicated) { - args.push_back(predicate_value); + // Predicated cp.async with 4 arguments + cp_async_args = {dst_access_ptr, src_access_ptr, PrimExpr(bytes), + predicate_value}; + } else { + // Non-predicated cp.async with 3 arguments + cp_async_args = {dst_access_ptr, src_access_ptr, PrimExpr(bytes)}; } return 
Evaluate(Call(store->buffer->dtype, - tvm::tir::builtin::ptx_cp_async(), args)); + tvm::tir::builtin::ptx_cp_async(), + cp_async_args)); } // Predicated load don't support vectorized indexing. @@ -134,14 +149,29 @@ class PTXAsyncCopyInjector : public StmtMutator { } return PrimExpr(); }(); + if (src_offset.defined() && dst_offset.defined()) { - return Evaluate(Call( - store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(), - {store->buffer->data, mul(dst_offset, PrimExpr(index_factor)), - load->buffer->data, src_offset, PrimExpr(bytes)})); + // Calculate the number of elements based on bytes and dtype + int dst_elem_count = bytes / dst_elem_type->bytes(); + int src_elem_count = bytes / src_elem_type->bytes(); + + // Create access_ptr for destination (shared memory, write access) + auto dst_access_ptr = store->buffer.access_ptr( + 2, DataType::Handle(), 1, dst_offset, PrimExpr(dst_elem_count)); + + // Create access_ptr for source (global memory, read access) + auto src_access_ptr = load->buffer.access_ptr( + 1, DataType::Handle(), 1, src_offset, PrimExpr(src_elem_count)); + + ffi::Array cp_async_args{dst_access_ptr, src_access_ptr, + PrimExpr(bytes)}; + return Evaluate(Call(store->buffer->dtype, + tvm::tir::builtin::ptx_cp_async(), + cp_async_args)); } } else { - // Only some vectorized indexing patterns are supported for now. + // Predicated vectorized cp.async - extract offsets from vectorized + // indices auto src_offset = [=]() -> PrimExpr { if (load->indices[0]->IsInstance()) { return load->indices[0].as()->base; @@ -154,8 +184,7 @@ class PTXAsyncCopyInjector : public StmtMutator { return store->indices[0].as()->base; } else if (store->indices[0].as()) { // The case where the dst buffer is a byte buffer generated by - // merging dynamic shared memory. A_shared.dyn[(ramp(...), 1, 8) + - // x8(17408))] = A_global[ramp(...),1, 8)] + // merging dynamic shared memory. 
auto *add = store->indices[0].as(); if (!add->a->IsInstance()) return PrimExpr(); @@ -168,11 +197,31 @@ class PTXAsyncCopyInjector : public StmtMutator { }(); if (src_offset.defined() && dst_offset.defined()) { - return Evaluate(Call( - store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(), - {store->buffer->data, mul(dst_offset, PrimExpr(index_factor)), - load->buffer->data, src_offset, PrimExpr(bytes), - predicate_value})); + // Calculate the number of elements based on bytes and dtype + int dst_elem_count = bytes / dst_elem_type->bytes(); + int src_elem_count = bytes / src_elem_type->bytes(); + + // Create access_ptr for destination (shared memory, write access) + auto dst_access_ptr = store->buffer.access_ptr( + 2, DataType::Handle(), 1, dst_offset, PrimExpr(dst_elem_count)); + + // Create access_ptr for source (global memory, read access) + auto src_access_ptr = load->buffer.access_ptr( + 1, DataType::Handle(), 1, src_offset, PrimExpr(src_elem_count)); + + // Predicated vectorized cp.async with 4 arguments + ffi::Array cp_async_args{dst_access_ptr, src_access_ptr, + PrimExpr(bytes), + predicate_value}; + return Evaluate(Call(store->buffer->dtype, + tvm::tir::builtin::ptx_cp_async(), + cp_async_args)); + } else { + // If we can't extract offsets from vectorized indices, fall back + LOG(WARNING) + << "Cannot extract offsets from vectorized indices for " + "predicated cp.async, " + << "falling back to regular buffer store/load"; } } } diff --git a/src/transform/merge_shared_memory_allocations.cc b/src/transform/merge_shared_memory_allocations.cc index 096fbd928..4b0c07a7e 100644 --- a/src/transform/merge_shared_memory_allocations.cc +++ b/src/transform/merge_shared_memory_allocations.cc @@ -566,29 +566,46 @@ class SharedMemoryRewriter : public StmtExprMutator { {op->args[0], merged_buf_var_, extra_offset + offset, extent, op->args[4]}); } else if (op->op.same_as(builtin::ptx_cp_async())) { - ICHECK((op->args.size() == 5U) || (op->args.size() == 6U)); - DataType
dtype = op->dtype; - Var buffer = Downcast(op->args[0]); + ICHECK(op->args.size() == 3U || op->args.size() == 4U) + << "ptx_cp_async expects 3 or 4 arguments (dst_access_ptr, " + "src_access_ptr, bytes, [predicate])"; + + // Extract dst_access_ptr and check if it needs merging + Call dst_access_ptr = Downcast(op->args[0]); + ICHECK(dst_access_ptr->op.same_as(builtin::tvm_access_ptr())) + << "First argument must be tvm_access_ptr"; + + // tvm_access_ptr(ptype, data, offset, extent, rw_mask) + Var buffer = Downcast(dst_access_ptr->args[1]); if (!IsAppropriateSharedMemory(buffer)) { return StmtExprMutator::VisitExpr_(op); } + + DataType dtype = op->dtype; PrimExpr extra_offset = GetBufferOffset(buffer, dtype); - PrimExpr offset = this->VisitExpr(op->args[1]); + PrimExpr offset = this->VisitExpr(dst_access_ptr->args[2]); // the dst shared memory is a byte buffer generated by merging shared // memory. we need to multiply the offset index by the byte size of the // original value dtype, to get the correct offset of merged shared // buffer.
int index_factor = dtype.bytes(); - if (op->args.size() == 5) - return Call(dtype, op->op, - {merged_buf_var_, - mul(extra_offset + offset, PrimExpr(index_factor)), - op->args[2], op->args[3], op->args[4]}); - else - return Call(dtype, op->op, - {merged_buf_var_, - mul(extra_offset + offset, PrimExpr(index_factor)), - op->args[2], op->args[3], op->args[4], op->args[5]}); + + // Create new dst_access_ptr with merged buffer and adjusted offset + auto new_dst_access_ptr = + Call(DataType::Handle(), builtin::tvm_access_ptr(), + { + dst_access_ptr->args[0], // ptype + merged_buf_var_, // merged buffer + mul(extra_offset + offset, + PrimExpr(index_factor)), // adjusted offset + dst_access_ptr->args[3], // extent + dst_access_ptr->args[4] // rw_mask + }); + + // Rewrite only dst; forward src_access_ptr, bytes, [predicate] unchanged + auto new_args = op->args; + new_args.Set(0, new_dst_access_ptr); + return Call(dtype, op->op, new_args); } else { return StmtExprMutator::VisitExpr_(op); } diff --git a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py index 9d232902c..33395a53d 100644 --- a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py +++ b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py @@ -269,7 +269,6 @@ def run_gemm_rs( num_stages, num_threads, ) - kernel = tilelang.compile( program, out_idx=[3], diff --git a/testing/python/transform/test_tilelang_transform_inject_fence_proxy.py b/testing/python/transform/test_tilelang_transform_inject_fence_proxy.py index 533a62fc6..a815e8e32 100644 --- a/testing/python/transform/test_tilelang_transform_inject_fence_proxy.py +++ b/testing/python/transform/test_tilelang_transform_inject_fence_proxy.py @@ -66,7 +66,11 @@ def before(): with T.Kernel(8): A_shared = T.decl_buffer((1024,), T.uint8, scope="shared.dyn") B_shared = T.decl_buffer((1024,), T.uint8, scope="shared.dyn") - T.ptx_cp_async("uint8", A_shared.data, 0, B_shared.data, 0, 16) + T.ptx_cp_async( +
T.tvm_access_ptr(T.type_annotation(T.uint8), A_shared.data, 0, 16, 2), + T.tvm_access_ptr(T.type_annotation(T.uint8), B_shared.data, 0, 16, 1), + 16, + ) T.fence_proxy_async() T.call_extern("handle", "generic_op") diff --git a/tilelang/language/tir/op.py b/tilelang/language/tir/op.py index 20876a944..e71cb91d9 100644 --- a/tilelang/language/tir/op.py +++ b/tilelang/language/tir/op.py @@ -1343,36 +1343,56 @@ def ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, sme return _tvm_op.ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, smem_offset) -def ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes): +def ptx_cp_async(dst_access_ptr, src_access_ptr, bytes, predicate=None): """TVM intrinsic for ptx async copy from global to shared memory using cp.async https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async Parameters ---------- - dtype : str - The data type of the result. + dst_access_ptr : PrimExpr + The destination (shared memory) access pointer created by tvm_access_ptr. + Should include pointer, offset, extent, and write access flag (rw_mask=2). - shared_ptr : Var - The shared memory pointer variable. + src_access_ptr : PrimExpr + The source (global memory) access pointer created by tvm_access_ptr. + Should include pointer, offset, extent, and read access flag (rw_mask=1). - shared_offset : Expr - The offset of shared memory pointer. + bytes : int or PrimExpr + The number of bytes to copy (must be 4, 8, or 16). - global_ptr : Var - The global memory pointer variable. - - global_offset : Expr - The offset of global memory pointer. - - bytes : int - The data size to copy. + predicate : PrimExpr, optional + Optional predicate condition for conditional cp.async. When provided, the copy + will only be performed if the predicate evaluates to true. Otherwise, the + destination will be filled with zeros (default behavior of cp.async). 
Returns ------- call : PrimExpr The call expression. - """ - return _tvm_op.ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes) + + Examples + -------- + >>> # Copy 16 bytes from global to shared memory + >>> T.ptx_cp_async( + ... T.tvm_access_ptr(T.type_annotation(T.uint8), A_shared.data, 0, 16, 2), # dst + ... T.tvm_access_ptr(T.type_annotation(T.uint8), B_global.data, 0, 16, 1), # src + ... 16 # bytes + ... ) + >>> + >>> # Predicated cp.async (only copy if condition is true) + >>> T.ptx_cp_async( + ... T.tvm_access_ptr(T.type_annotation(T.uint8), A_shared.data, 0, 16, 2), + ... T.tvm_access_ptr(T.type_annotation(T.uint8), B_global.data, 0, 16, 1), + ... 16, + ... predicate=guard # only copy if guard is true + ... ) + """ + from tvm import tir + + if predicate is None: + return tir.call_intrin("", tir.op.Op.get("tl.ptx_cp_async"), dst_access_ptr, src_access_ptr, bytes) + else: + return tir.call_intrin("", tir.op.Op.get("tl.ptx_cp_async"), dst_access_ptr, src_access_ptr, bytes, predicate) def ptx_cp_async_bulk(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes, barrier_id):