Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 15 additions & 13 deletions src/target/codegen_cutedsl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -312,23 +312,25 @@ void CodeGenTileLangCuTeDSL::VisitExpr_(const CallNode *op,
};

if (op->op.same_as(builtin::ptx_cp_async())) {
// args[0] = dst_access_ptr, args[1] = src_access_ptr, args[2] = bytes,
// args[3] = predicate (optional)
ICHECK(op->args.size() == 3 || op->args.size() == 4)
<< "ptx_cp_async expects 3 or 4 arguments (dst_access_ptr, "
"src_access_ptr, bytes, [predicate])";

std::string dst = PrintExpr_(op->args[0]);
std::string dst_offset = PrintExpr_(op->args[1]);
std::string src = PrintExpr_(op->args[2]);
std::string src_offset = PrintExpr_(op->args[3]);
std::string size = PrintExpr_(op->args[4]);
// use size of argument list to indicate whether or not to use predicated
// cp.async
if (op->args.size() == 5) {
PrintIndent();
stream << "tl.cp_async_gs(" << size << ", " << dst << ", " << dst_offset
<< ", " << src << ", " << src_offset << ")\n";
std::string src = PrintExpr_(op->args[1]);
std::string size = PrintExpr_(op->args[2]);

this->PrintIndent();
if (op->args.size() == 3) {
stream << "tl.cp_async_gs(" << size << ", " << dst << ", " << src
<< ")\n";
} else {
std::string condition = PrintExpr_(op->args[5]);
std::string condition = PrintExpr_(op->args[3]);
PrintIndent();
stream << "tl.cp_async_gs_conditional(" << size << ", " << dst << ", "
<< dst_offset << ", " << src << ", " << src_offset << ", "
<< condition << ")\n";
<< src << ", " << condition << ")\n";
}
Comment on lines 314 to 334
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

# First, let's look at the actual file to confirm the code structure
cat -n src/target/codegen_cutedsl.cc | sed -n '310,340p'

Repository: tile-ai/tilelang

Length of output: 1461


🏁 Script executed:

# Find PrintIndent() definition to understand its behavior
grep -n "PrintIndent()" src/target/codegen_cutedsl.cc | head -20

Repository: tile-ai/tilelang

Length of output: 525


🏁 Script executed:

# Find the PrintIndent() method definition
grep -n "PrintIndent\(\)" src/target/codegen_cutedsl.cc -B 5 -A 2 | grep -E "(void|class|Print)" | head -10

Repository: tile-ai/tilelang

Length of output: 439


🏁 Script executed:

# Search for PrintIndent method definition in the class
rg "void PrintIndent\(\)|void.*PrintIndent" src/target/codegen_cutedsl.cc -A 3

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Look for PrintIndent in header files
fd -t f -e h -e hpp | xargs grep -l "PrintIndent" | head -5

Repository: tile-ai/tilelang

Length of output: 87


🏁 Script executed:

# Check the header file for PrintIndent definition
grep -n "PrintIndent" src/target/codegen_c_host.h -B 2 -A 2

Repository: tile-ai/tilelang

Length of output: 215


🏁 Script executed:

# Search more broadly for PrintIndent definition
rg "PrintIndent\s*\(\)" -A 2 src/target/ | grep -A 2 "void\|stream"

Repository: tile-ai/tilelang

Length of output: 22158


🏁 Script executed:

# Find the base class of CodeGenTileLangCuTeDSL to understand PrintIndent
grep -n "class CodeGenTileLangCuTeDSL" src/target/codegen_cutedsl.cc -A 2

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Search for CodeGenTileLangCuTeDSL class definition
fd -t f -e h -e hpp | xargs grep -l "CodeGenTileLangCuTeDSL" | head -3

Repository: tile-ai/tilelang

Length of output: 88


🏁 Script executed:

# Check the header file for class definition and PrintIndent
cat -n src/target/codegen_cutedsl.h | grep -E "class CodeGenTileLangCuTeDSL|PrintIndent" -B 2 -A 2

Repository: tile-ai/tilelang

Length of output: 215


🏁 Script executed:

# Find CodeGenTileLangPY to understand the inheritance chain
fd -t f -e h -e hpp | xargs grep -l "CodeGenTileLangPY" | head -3

Repository: tile-ai/tilelang

Length of output: 112


🏁 Script executed:

# Check CodeGenTileLangPY to understand the inheritance
grep -n "class CodeGenTileLangPY" src/target/codegen_py.h -A 2

Repository: tile-ai/tilelang

Length of output: 204


🏁 Script executed:

# Search for PrintIndent implementation in codegen files
rg "void.*PrintIndent\s*\(\)" src/target/ -A 2

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Search for PrintIndent definition more broadly
rg "PrintIndent" src/target/codegen_py.h -B 2 -A 2

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Search in all header files for PrintIndent definition
fd -t f -e h -e hpp | xargs grep "PrintIndent" | head -10

Repository: tile-ai/tilelang

Length of output: 112


🏁 Script executed:

# Look for the base class that might have PrintIndent - check StmtFunctor or other bases
rg "PrintIndent" src/ -B 5 -A 2 | grep -E "(void|protected|public|private)" | head -20

Repository: tile-ai/tilelang

Length of output: 1872


Remove the redundant PrintIndent() call in the 4‑arg ptx_cp_async branch.

PrintIndent() is called at line 325 before the if/else and again at line 331 within the else branch for the 4‑arg case. This emits two indentation prefixes, which produces invalid Python syntax. Move PrintIndent() into each branch so it is called exactly once per path.

🛠️ Proposed fix
-    this->PrintIndent();
-    if (op->args.size() == 3) {
-      stream << "tl.cp_async_gs(" << size << ", " << dst << ", " << src
-             << ")\n";
-    } else {
-      std::string condition = PrintExpr_(op->args[3]);
-      PrintIndent();
-      stream << "tl.cp_async_gs_conditional(" << size << ", " << dst << ", "
-             << src << ", " << condition << ")\n";
-    }
+    if (op->args.size() == 3) {
+      PrintIndent();
+      stream << "tl.cp_async_gs(" << size << ", " << dst << ", " << src
+             << ")\n";
+    } else {
+      std::string condition = PrintExpr_(op->args[3]);
+      PrintIndent();
+      stream << "tl.cp_async_gs_conditional(" << size << ", " << dst << ", "
+             << src << ", " << condition << ")\n";
+    }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if (op->op.same_as(builtin::ptx_cp_async())) {
// args[0] = dst_access_ptr, args[1] = src_access_ptr, args[2] = bytes,
// args[3] = predicate (optional)
ICHECK(op->args.size() == 3 || op->args.size() == 4)
<< "ptx_cp_async expects 3 or 4 arguments (dst_access_ptr, "
"src_access_ptr, bytes, [predicate])";
std::string dst = PrintExpr_(op->args[0]);
std::string dst_offset = PrintExpr_(op->args[1]);
std::string src = PrintExpr_(op->args[2]);
std::string src_offset = PrintExpr_(op->args[3]);
std::string size = PrintExpr_(op->args[4]);
// use size of argument list to indicate whether or not to use predicated
// cp.async
if (op->args.size() == 5) {
PrintIndent();
stream << "tl.cp_async_gs(" << size << ", " << dst << ", " << dst_offset
<< ", " << src << ", " << src_offset << ")\n";
std::string src = PrintExpr_(op->args[1]);
std::string size = PrintExpr_(op->args[2]);
this->PrintIndent();
if (op->args.size() == 3) {
stream << "tl.cp_async_gs(" << size << ", " << dst << ", " << src
<< ")\n";
} else {
std::string condition = PrintExpr_(op->args[5]);
std::string condition = PrintExpr_(op->args[3]);
PrintIndent();
stream << "tl.cp_async_gs_conditional(" << size << ", " << dst << ", "
<< dst_offset << ", " << src << ", " << src_offset << ", "
<< condition << ")\n";
<< src << ", " << condition << ")\n";
}
if (op->op.same_as(builtin::ptx_cp_async())) {
// args[0] = dst_access_ptr, args[1] = src_access_ptr, args[2] = bytes,
// args[3] = predicate (optional)
ICHECK(op->args.size() == 3 || op->args.size() == 4)
<< "ptx_cp_async expects 3 or 4 arguments (dst_access_ptr, "
"src_access_ptr, bytes, [predicate])";
std::string dst = PrintExpr_(op->args[0]);
std::string src = PrintExpr_(op->args[1]);
std::string size = PrintExpr_(op->args[2]);
if (op->args.size() == 3) {
PrintIndent();
stream << "tl.cp_async_gs(" << size << ", " << dst << ", " << src
<< ")\n";
} else {
std::string condition = PrintExpr_(op->args[3]);
PrintIndent();
stream << "tl.cp_async_gs_conditional(" << size << ", " << dst << ", "
<< src << ", " << condition << ")\n";
}
🤖 Prompt for AI Agents
In `@src/target/codegen_cutedsl.cc` around lines 314 - 334, In the ptx_cp_async
handler (op->op.same_as(builtin::ptx_cp_async())), remove the unconditional
PrintIndent() before the if/else and instead call PrintIndent() once inside each
branch so indentation is emitted exactly once per path; i.e., call PrintIndent()
before stream << "tl.cp_async_gs(...)\n" in the 3-arg branch and call
PrintIndent() before stream << "tl.cp_async_gs_conditional(...)\n" in the 4-arg
branch to avoid double indentation and invalid Python output.

} else if (op->op.same_as(builtin::ptx_commit_group())) {
print_extern_call_stmt("tl.cp_async_commit");
Expand Down
1 change: 1 addition & 0 deletions src/target/codegen_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ void CodeGenTileLangPY::ReserveKeywordsAsUnique_() {

// Writes one assignment statement of the form `<target> = <value>` followed by
// a newline to `stream`, where <value> is `src` passed through
// RemoveOutermostParentheses. PrintIndent() is invoked first so the statement
// is preceded by whatever indentation prefix that helper emits.
// The DataType argument is not consulted by this implementation.
void CodeGenTileLangPY::PrintSSAAssign(const std::string &target,
                                       const std::string &src, DataType t) {
  PrintIndent();
  const auto rhs = RemoveOutermostParentheses(src);
  stream << target;
  stream << " = ";
  stream << rhs;
  stream << "\n";
}

Expand Down
8 changes: 4 additions & 4 deletions tilelang/contrib/cutedsl/cpasync.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
BYTES_PER_POINTER = 8


def cp_async_gs(size, dst, dst_offset, src, src_offset):
def cp_async_gs(size, dst, src):
assert size in [16, 8, 4]
# use CG (cache global) to bypass L1 when loading contiguous 128B.
mode = nvvm.LoadCacheModifierKind.CG if size == 16 else nvvm.LoadCacheModifierKind.CA
Expand All @@ -34,13 +34,13 @@ def cp_async_gs(size, dst, dst_offset, src, src_offset):
dst_ptr = dst
else:
raise ValueError(f"Invalid destination type: {type(dst)}")
cp_async_shared_global(dst_ptr + dst_offset, src_ptr + src_offset, size, mode)
cp_async_shared_global(dst_ptr, src_ptr, size, mode)


@cute.jit
def cp_async_gs_conditional(size, dst, dst_offset, src, src_offset, cond):
def cp_async_gs_conditional(size, dst, src, cond):
if cond:
cp_async_gs(size, dst, dst_offset, src, src_offset)
cp_async_gs(size, dst, src)


@dsl_user_op
Expand Down
Loading