diff --git a/src/target/codegen_cutedsl.cc b/src/target/codegen_cutedsl.cc
index 66e4a8c24..daa414d8d 100644
--- a/src/target/codegen_cutedsl.cc
+++ b/src/target/codegen_cutedsl.cc
@@ -312,23 +312,27 @@ void CodeGenTileLangCuTeDSL::VisitExpr_(const CallNode *op,
   };
 
   if (op->op.same_as(builtin::ptx_cp_async())) {
+    // args[0] = dst_access_ptr, args[1] = src_access_ptr, args[2] = bytes,
+    // args[3] = predicate (optional)
+    ICHECK(op->args.size() == 3 || op->args.size() == 4)
+        << "ptx_cp_async expects 3 or 4 arguments (dst_access_ptr, "
+           "src_access_ptr, bytes, [predicate])";
+
     std::string dst = PrintExpr_(op->args[0]);
-    std::string dst_offset = PrintExpr_(op->args[1]);
-    std::string src = PrintExpr_(op->args[2]);
-    std::string src_offset = PrintExpr_(op->args[3]);
-    std::string size = PrintExpr_(op->args[4]);
-    // use size of argument list to indicate whether or not to use predicated
-    // cp.async
-    if (op->args.size() == 5) {
+    std::string src = PrintExpr_(op->args[1]);
+    std::string size = PrintExpr_(op->args[2]);
+
+    // Each branch calls PrintIndent() exactly once, and only after all of its
+    // operands have been rendered, so the emitted call is never indented twice.
+    if (op->args.size() == 3) {
       PrintIndent();
-      stream << "tl.cp_async_gs(" << size << ", " << dst << ", " << dst_offset
-             << ", " << src << ", " << src_offset << ")\n";
+      stream << "tl.cp_async_gs(" << size << ", " << dst << ", " << src
+             << ")\n";
     } else {
-      std::string condition = PrintExpr_(op->args[5]);
+      std::string condition = PrintExpr_(op->args[3]);
       PrintIndent();
       stream << "tl.cp_async_gs_conditional(" << size << ", " << dst << ", "
-             << dst_offset << ", " << src << ", " << src_offset << ", "
-             << condition << ")\n";
+             << src << ", " << condition << ")\n";
     }
   } else if (op->op.same_as(builtin::ptx_commit_group())) {
     print_extern_call_stmt("tl.cp_async_commit");
diff --git a/src/target/codegen_py.cc b/src/target/codegen_py.cc
index aa12eef09..6e0b787bb 100644
--- a/src/target/codegen_py.cc
+++ b/src/target/codegen_py.cc
@@ -155,6 +155,7 @@ void CodeGenTileLangPY::ReserveKeywordsAsUnique_() {
 
 void CodeGenTileLangPY::PrintSSAAssign(const std::string &target,
                                        const std::string &src, DataType t) {
+  PrintIndent();
   stream << target << " = " << RemoveOutermostParentheses(src) << "\n";
 }
 
diff --git a/tilelang/contrib/cutedsl/cpasync.py b/tilelang/contrib/cutedsl/cpasync.py
index 6ddeb8933..c5a4742a1 100644
--- a/tilelang/contrib/cutedsl/cpasync.py
+++ b/tilelang/contrib/cutedsl/cpasync.py
@@ -18,7 +18,7 @@ BYTES_PER_POINTER = 8
 
 
 
-def cp_async_gs(size, dst, dst_offset, src, src_offset):
+def cp_async_gs(size, dst, src):
     assert size in [16, 8, 4]
     # use CG (cache global) to by pass L1 when loading contiguous 128B.
     mode = nvvm.LoadCacheModifierKind.CG if size == 16 else nvvm.LoadCacheModifierKind.CA
@@ -34,13 +34,13 @@ def cp_async_gs(size, dst, dst_offset, src, src_offset):
         dst_ptr = dst
     else:
         raise ValueError(f"Invalid destination type: {type(dst)}")
-    cp_async_shared_global(dst_ptr + dst_offset, src_ptr + src_offset, size, mode)
+    cp_async_shared_global(dst_ptr, src_ptr, size, mode)
 
 
 @cute.jit
-def cp_async_gs_conditional(size, dst, dst_offset, src, src_offset, cond):
+def cp_async_gs_conditional(size, dst, src, cond):
     if cond:
-        cp_async_gs(size, dst, dst_offset, src, src_offset)
+        cp_async_gs(size, dst, src)
 
 
 @dsl_user_op