-
Notifications
You must be signed in to change notification settings - Fork 445
[BugFix] Fix several bugs in CodeGen for CuTeDSL backend #1746
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀 |
📝 WalkthroughWalkthroughThis PR simplifies the cp_async function signatures across the codebase by removing per-operand offset parameters. The ptx_cp_async codegen now uses a 3- or 4-argument format (dst, src, bytes, and optional predicate), and corresponding function signatures in the TileLang contrib module are updated. A minor indentation fix is also applied to SSA assignment emission. Changes
Estimated code review effort🎯 2 (Simple) | ⏱️ ~10 minutes Possibly related PRs
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 3✅ Passed checks (3 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing touches
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@src/target/codegen_cutedsl.cc`:
- Around line 314-334: In the ptx_cp_async handler
(op->op.same_as(builtin::ptx_cp_async())), remove the unconditional
PrintIndent() before the if/else and instead call PrintIndent() once inside each
branch so indentation is emitted exactly once per path; i.e., call PrintIndent()
before stream << "tl.cp_async_gs(...)\n" in the 3-arg branch and call
PrintIndent() before stream << "tl.cp_async_gs_conditional(...)\n" in the 4-arg
branch to avoid double indentation and invalid Python output.
| if (op->op.same_as(builtin::ptx_cp_async())) { | ||
| // args[0] = dst_access_ptr, args[1] = src_access_ptr, args[2] = bytes, | ||
| // args[3] = predicate (optional) | ||
| ICHECK(op->args.size() == 3 || op->args.size() == 4) | ||
| << "ptx_cp_async expects 3 or 4 arguments (dst_access_ptr, " | ||
| "src_access_ptr, bytes, [predicate])"; | ||
|
|
||
| std::string dst = PrintExpr_(op->args[0]); | ||
| std::string dst_offset = PrintExpr_(op->args[1]); | ||
| std::string src = PrintExpr_(op->args[2]); | ||
| std::string src_offset = PrintExpr_(op->args[3]); | ||
| std::string size = PrintExpr_(op->args[4]); | ||
| // use size of argument list to indicate whether or not to use predicated | ||
| // cp.async | ||
| if (op->args.size() == 5) { | ||
| PrintIndent(); | ||
| stream << "tl.cp_async_gs(" << size << ", " << dst << ", " << dst_offset | ||
| << ", " << src << ", " << src_offset << ")\n"; | ||
| std::string src = PrintExpr_(op->args[1]); | ||
| std::string size = PrintExpr_(op->args[2]); | ||
|
|
||
| this->PrintIndent(); | ||
| if (op->args.size() == 3) { | ||
| stream << "tl.cp_async_gs(" << size << ", " << dst << ", " << src | ||
| << ")\n"; | ||
| } else { | ||
| std::string condition = PrintExpr_(op->args[5]); | ||
| std::string condition = PrintExpr_(op->args[3]); | ||
| PrintIndent(); | ||
| stream << "tl.cp_async_gs_conditional(" << size << ", " << dst << ", " | ||
| << dst_offset << ", " << src << ", " << src_offset << ", " | ||
| << condition << ")\n"; | ||
| << src << ", " << condition << ")\n"; | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🏁 Script executed:
# First, let's look at the actual file to confirm the code structure
cat -n src/target/codegen_cutedsl.cc | sed -n '310,340p'Repository: tile-ai/tilelang
Length of output: 1461
🏁 Script executed:
# Find PrintIndent() definition to understand its behavior
grep -n "PrintIndent()" src/target/codegen_cutedsl.cc | head -20Repository: tile-ai/tilelang
Length of output: 525
🏁 Script executed:
# Find the PrintIndent() method definition
grep -n "PrintIndent\(\)" src/target/codegen_cutedsl.cc -B 5 -A 2 | grep -E "(void|class|Print)" | head -10Repository: tile-ai/tilelang
Length of output: 439
🏁 Script executed:
# Search for PrintIndent method definition in the class
rg "void PrintIndent\(\)|void.*PrintIndent" src/target/codegen_cutedsl.cc -A 3Repository: tile-ai/tilelang
Length of output: 42
🏁 Script executed:
# Look for PrintIndent in header files
fd -t f -e h -e hpp | xargs grep -l "PrintIndent" | head -5Repository: tile-ai/tilelang
Length of output: 87
🏁 Script executed:
# Check the header file for PrintIndent definition
grep -n "PrintIndent" src/target/codegen_c_host.h -B 2 -A 2Repository: tile-ai/tilelang
Length of output: 215
🏁 Script executed:
# Search more broadly for PrintIndent definition
rg "PrintIndent\s*\(\)" -A 2 src/target/ | grep -A 2 "void\|stream"Repository: tile-ai/tilelang
Length of output: 22158
🏁 Script executed:
# Find the base class of CodeGenTileLangCuTeDSL to understand PrintIndent
grep -n "class CodeGenTileLangCuTeDSL" src/target/codegen_cutedsl.cc -A 2Repository: tile-ai/tilelang
Length of output: 42
🏁 Script executed:
# Search for CodeGenTileLangCuTeDSL class definition
fd -t f -e h -e hpp | xargs grep -l "CodeGenTileLangCuTeDSL" | head -3Repository: tile-ai/tilelang
Length of output: 88
🏁 Script executed:
# Check the header file for class definition and PrintIndent
cat -n src/target/codegen_cutedsl.h | grep -E "class CodeGenTileLangCuTeDSL|PrintIndent" -B 2 -A 2Repository: tile-ai/tilelang
Length of output: 215
🏁 Script executed:
# Find CodeGenTileLangPY to understand the inheritance chain
fd -t f -e h -e hpp | xargs grep -l "CodeGenTileLangPY" | head -3Repository: tile-ai/tilelang
Length of output: 112
🏁 Script executed:
# Check CodeGenTileLangPY to understand the inheritance
grep -n "class CodeGenTileLangPY" src/target/codegen_py.h -A 2Repository: tile-ai/tilelang
Length of output: 204
🏁 Script executed:
# Search for PrintIndent implementation in codegen files
rg "void.*PrintIndent\s*\(\)" src/target/ -A 2Repository: tile-ai/tilelang
Length of output: 42
🏁 Script executed:
# Search for PrintIndent definition more broadly
rg "PrintIndent" src/target/codegen_py.h -B 2 -A 2Repository: tile-ai/tilelang
Length of output: 42
🏁 Script executed:
# Search in all header files for PrintIndent definition
fd -t f -e h -e hpp | xargs grep "PrintIndent" | head -10Repository: tile-ai/tilelang
Length of output: 112
🏁 Script executed:
# Look for the base class that might have PrintIndent - check StmtFunctor or other bases
rg "PrintIndent" src/ -B 5 -A 2 | grep -E "(void|protected|public|private)" | head -20Repository: tile-ai/tilelang
Length of output: 1872
Remove the redundant PrintIndent() call in the 4‑arg ptx_cp_async branch.
PrintIndent() is called at line 325 before the if/else and again at line 331 within the else branch for the 4‑arg case. This emits two indentation prefixes, which produces invalid Python syntax. Move PrintIndent() into each branch so it is called exactly once per path.
🛠️ Proposed fix
- this->PrintIndent();
- if (op->args.size() == 3) {
- stream << "tl.cp_async_gs(" << size << ", " << dst << ", " << src
- << ")\n";
- } else {
- std::string condition = PrintExpr_(op->args[3]);
- PrintIndent();
- stream << "tl.cp_async_gs_conditional(" << size << ", " << dst << ", "
- << src << ", " << condition << ")\n";
- }
+ if (op->args.size() == 3) {
+ PrintIndent();
+ stream << "tl.cp_async_gs(" << size << ", " << dst << ", " << src
+ << ")\n";
+ } else {
+ std::string condition = PrintExpr_(op->args[3]);
+ PrintIndent();
+ stream << "tl.cp_async_gs_conditional(" << size << ", " << dst << ", "
+ << src << ", " << condition << ")\n";
+ }📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| if (op->op.same_as(builtin::ptx_cp_async())) { | |
| // args[0] = dst_access_ptr, args[1] = src_access_ptr, args[2] = bytes, | |
| // args[3] = predicate (optional) | |
| ICHECK(op->args.size() == 3 || op->args.size() == 4) | |
| << "ptx_cp_async expects 3 or 4 arguments (dst_access_ptr, " | |
| "src_access_ptr, bytes, [predicate])"; | |
| std::string dst = PrintExpr_(op->args[0]); | |
| std::string dst_offset = PrintExpr_(op->args[1]); | |
| std::string src = PrintExpr_(op->args[2]); | |
| std::string src_offset = PrintExpr_(op->args[3]); | |
| std::string size = PrintExpr_(op->args[4]); | |
| // use size of argument list to indicate whether or not to use predicated | |
| // cp.async | |
| if (op->args.size() == 5) { | |
| PrintIndent(); | |
| stream << "tl.cp_async_gs(" << size << ", " << dst << ", " << dst_offset | |
| << ", " << src << ", " << src_offset << ")\n"; | |
| std::string src = PrintExpr_(op->args[1]); | |
| std::string size = PrintExpr_(op->args[2]); | |
| this->PrintIndent(); | |
| if (op->args.size() == 3) { | |
| stream << "tl.cp_async_gs(" << size << ", " << dst << ", " << src | |
| << ")\n"; | |
| } else { | |
| std::string condition = PrintExpr_(op->args[5]); | |
| std::string condition = PrintExpr_(op->args[3]); | |
| PrintIndent(); | |
| stream << "tl.cp_async_gs_conditional(" << size << ", " << dst << ", " | |
| << dst_offset << ", " << src << ", " << src_offset << ", " | |
| << condition << ")\n"; | |
| << src << ", " << condition << ")\n"; | |
| } | |
| if (op->op.same_as(builtin::ptx_cp_async())) { | |
| // args[0] = dst_access_ptr, args[1] = src_access_ptr, args[2] = bytes, | |
| // args[3] = predicate (optional) | |
| ICHECK(op->args.size() == 3 || op->args.size() == 4) | |
| << "ptx_cp_async expects 3 or 4 arguments (dst_access_ptr, " | |
| "src_access_ptr, bytes, [predicate])"; | |
| std::string dst = PrintExpr_(op->args[0]); | |
| std::string src = PrintExpr_(op->args[1]); | |
| std::string size = PrintExpr_(op->args[2]); | |
| if (op->args.size() == 3) { | |
| PrintIndent(); | |
| stream << "tl.cp_async_gs(" << size << ", " << dst << ", " << src | |
| << ")\n"; | |
| } else { | |
| std::string condition = PrintExpr_(op->args[3]); | |
| PrintIndent(); | |
| stream << "tl.cp_async_gs_conditional(" << size << ", " << dst << ", " | |
| << src << ", " << condition << ")\n"; | |
| } |
🤖 Prompt for AI Agents
In `@src/target/codegen_cutedsl.cc` around lines 314 - 334, In the ptx_cp_async
handler (op->op.same_as(builtin::ptx_cp_async())), remove the unconditional
PrintIndent() before the if/else and instead call PrintIndent() once inside each
branch so indentation is emitted exactly once per path; i.e., call PrintIndent()
before stream << "tl.cp_async_gs(...)\n" in the 3-arg branch and call
PrintIndent() before stream << "tl.cp_async_gs_conditional(...)\n" in the 4-arg
branch to avoid double indentation and invalid Python output.
as titled
Summary by CodeRabbit
✏️ Tip: You can customize this high-level summary in your review settings.