[Refactor] Improve assertion handling in CodeGenCHost and ArgBinder #1352
This commit refines the assertion message generation in CodeGenCHost by optimizing the handling of equality checks and reducing buffer size for error messages. Additionally, it enhances the ArgBinder by introducing a nullable guard mechanism for assertions, allowing for more precise error handling when binding arguments. The changes improve the clarity and efficiency of assertion handling across the codebase.
Caution: Review failed. The pull request is closed.
Walkthrough
Simplify host assertion messages to handle single EQ comparisons; propagate NULL-awareness through ArgBinder and BindDLTensor with an
Changes
Sequence Diagram(s)
sequenceDiagram
autonumber
participant Detector as UsedBufferDetector (pre-pass)
participant MakePackedAPI as MakePackedAPI pass
participant ArgBinder as ArgBinder / BindDLTensor
participant MergeIf as MergeIfStmtSubstitute
participant CodeGen as CodeGenCHost
Detector->>MakePackedAPI: build data_var2param / shape_var2params, mark used buffers
MakePackedAPI->>ArgBinder: call BindDLTensor(buffer, ..., arg_name, is_used)
ArgBinder->>ArgBinder: emit guarded assertions (nullable_guard → IfThenElse) or bind non-NULL path
MakePackedAPI->>MergeIf: invoke MergeIfStmtSubstitute on lowered PrimFunc
MergeIf->>MergeIf: flatten SeqStmt and merge consecutive IfThenElse with same condition
MergeIf->>MakePackedAPI: return transformed PrimFunc
MakePackedAPI->>CodeGen: hand off transformed PrimFunc
CodeGen->>CodeGen: emit host asserts (single‑EQ diagnostics) and call runtime dtype helpers as needed
Estimated code review effort
🎯 4 (Complex) | ⏱️ ~45 minutes
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
📜 Recent review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
src/transform/merge_if_stmt.cc (2)
1-4: Fix incorrect file name in header comment.
The file header comment incorrectly states `\file if_stmt_binding.cc` but the actual file is `merge_if_stmt.cc`.
 /*!
- * \file if_stmt_binding.cc
+ * \file merge_if_stmt.cc
  * \brief Merge the If Stmt in SeqStmt
  */
72-77: Bodies are visited twice when merged into SeqStmt.
When merging bodies into a `SeqStmt`, the code calls `this->VisitStmt(SeqStmt(current_if_bodies))`. However, the `current_if_bodies` contain `then_case` statements from already-processed `if_node`s (lines 68, 83). These statements come from the `flat_seq`, which was already visited at lines 53-54.
This means the bodies are visited twice, which could cause issues with transformations that aren't idempotent.
Consider removing the redundant `VisitStmt` call since bodies were already visited:
 auto if_stmt =
     IfThenElse(current_condition,
                current_if_bodies.size() == 1
                    ? current_if_bodies[0]
-                   : this->VisitStmt(SeqStmt(current_if_bodies)),
+                   : SeqStmt(current_if_bodies),
                Stmt());
Apply this change to all three occurrences (lines 76, 94, 109).
Also applies to: 90-95, 105-110
🧹 Nitpick comments (2)
src/transform/arg_binder.cc (1)
44-55: Consider whether the trailing `Evaluate(0)` in `SeqStmt` is necessary.
The pattern `SeqStmt({check, Evaluate(0)})` wraps the guarded assertion with a nop statement. While this works, it creates slightly larger IR trees. If this is intentional for pass compatibility reasons, a brief comment explaining why would be helpful.
src/transform/merge_if_stmt.h (1)
40-40: Consider using const reference or value parameter.
The function takes a non-const reference but also returns the modified `PrimFunc`. This is an unusual API pattern. Typically you'd either:
- Take by value/const-ref and return the modified copy, or
- Take by non-const ref, mutate in-place, and return void
The current signature allows mutation of the input while also returning it, which could be confusing.
Consider changing to:
-PrimFunc MergeIfStmtSubstitute(PrimFunc &f);
+PrimFunc MergeIfStmtSubstitute(PrimFunc f);
This would require updating the implementation in `merge_if_stmt.cc` accordingly.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- src/target/codegen_c_host.cc (1 hunks)
- src/transform/arg_binder.cc (16 hunks)
- src/transform/make_packed_api.cc (2 hunks)
- src/transform/merge_if_stmt.cc (4 hunks)
- src/transform/merge_if_stmt.h (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-12T09:47:46.474Z
Learnt from: kurisu6912
Repo: tile-ai/tilelang PR: 794
File: tilelang/transform/add_bufstore_wrapper.py:30-33
Timestamp: 2025-09-12T09:47:46.474Z
Learning: In TVM's PyStmtExprMutator, visit_block_ methods typically call super().visit_block_(op) to process child nodes and update internal state, but return the original op when the block itself doesn't need transformation. The pattern `return op` is correct for blocks that serve as containers where mutations happen at deeper levels.
Applied to files:
src/target/codegen_c_host.cc
🧬 Code graph analysis (3)
src/transform/merge_if_stmt.h (1)
src/transform/merge_if_stmt.cc (9)
MergeIfStmtRewriter(35-35)MergeIfStmtSubstitute(118-120)MergeIfStmtSubstitute(118-118)f(24-27)f(24-24)ApplyMergeIfStmt(122-122)ApplyMergeIfStmt(122-122)stmt(29-32)stmt(29-29)
src/transform/make_packed_api.cc (1)
src/transform/merge_if_stmt.cc (2)
MergeIfStmtSubstitute(118-120)MergeIfStmtSubstitute(118-118)
src/transform/merge_if_stmt.cc (1)
src/transform/if_stmt_binding.cc (2)
f(22-26)f(22-22)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Quick Lint
🔇 Additional comments (2)
src/target/codegen_c_host.cc (1)
358-383: LGTM! The simplified single-EQ assertion path is cleaner.
The refactored code now handles only single equality checks directly, which is more straightforward than the previous multi-EQ collection approach. The 512-byte buffer is reasonable for typical assertion messages.
One minor observation: the format specifier `%lld` expects `long long`, and you're casting to `(long long)`, which is correct.
src/transform/make_packed_api.cc (1)
469-470: LGTM! Integration of MergeIfStmt pass is correctly placed.
The `MergeIfStmtSubstitute` call after `MakePackedAPI` ensures that the consecutive if-statements generated by the nullable guard pattern get merged, which should improve the generated C code quality.
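To make the merging step concrete, here is a minimal Python sketch of the idea, assuming a toy IR rather than TVM's actual `IfThenElse`/`SeqStmt` nodes. The classes `Assert` and `If` and the function `merge_consecutive_ifs` are hypothetical names used only for illustration; the real pass lives in `src/transform/merge_if_stmt.cc` and may differ in detail.

```python
from dataclasses import dataclass
from typing import List, Union


@dataclass(frozen=True)
class Assert:
    """Toy assertion statement (hypothetical, not a TVM class)."""
    message: str


@dataclass(frozen=True)
class If:
    """Toy if-statement without an else branch (hypothetical)."""
    cond: str
    body: tuple  # tuple of statements executed when cond holds


Stmt = Union[Assert, If]


def merge_consecutive_ifs(stmts: List[Stmt]) -> List[Stmt]:
    """Merge runs of adjacent If nodes that share the same condition."""
    merged: List[Stmt] = []
    for stmt in stmts:
        prev = merged[-1] if merged else None
        if isinstance(stmt, If) and isinstance(prev, If) and prev.cond == stmt.cond:
            # Same guard as the previous statement: append the bodies.
            merged[-1] = If(prev.cond, prev.body + stmt.body)
        else:
            merged.append(stmt)
    return merged


# Three NULL-guarded checks, interrupted by one unguarded check.
guarded = [
    If("!is_null(A)", (Assert("A.ndim == 2"),)),
    If("!is_null(A)", (Assert("A.dtype == float16"),)),
    Assert("num_args == 3"),
    If("!is_null(A)", (Assert("A.shape[0] == M"),)),
]
result = merge_consecutive_ifs(guarded)
```

Only adjacent statements are merged in this sketch: the last guarded check stays separate because an unguarded statement sits between it and the earlier run.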
    PrimExpr cond = value == arg;
    BinderAddAssert(&analyzer_, cond, arg_name, &asserts_, nullable_guard);
  }
  // ICHECK(false);
🛠️ Refactor suggestion | 🟠 Major
Remove commented-out debug code.
The // ICHECK(false); appears to be leftover debug code that should be removed.
-  // ICHECK(false);
   return false;
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
-  // ICHECK(false);
   return false;
🤖 Prompt for AI Agents
In src/transform/arg_binder.cc around line 142, remove the commented-out debug
line "// ICHECK(false);" because it is leftover debug code; delete the comment
so the file contains only active, meaningful code and no commented-out ICHECK
statement.
    BindNullable(offset, truncmod(offset, factor), arg_name + ".elem_offset",
                 true, is_null);
  }
Potential incorrect usage of BindNullable for offset factor check.
The call BindNullable(offset, truncmod(offset, factor), ...) appears to be checking that offset == truncmod(offset, factor), which passes whenever |offset| < factor rather than when offset % factor == 0. The original logic should instead assert that truncmod(offset, factor) == 0. This seems like a logic error.
Consider this fix:
 if (buffer->offset_factor > 1) {
   PrimExpr offset = buffer->elem_offset;
   PrimExpr factor = make_const(offset.dtype(), buffer->offset_factor);
   PrimExpr zero = make_zero(offset.dtype());
-  BindNullable(offset, truncmod(offset, factor), arg_name + ".elem_offset",
-               true, is_null);
+  BinderAddAssert(&analyzer_, zero == truncmod(offset, factor),
+                  arg_name + ".elem_offset", &asserts_, is_null);
 }
Committable suggestion skipped: line range outside the PR's diff.
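The difference between the two predicates can be checked with plain integers. The following is a minimal Python sketch, assuming `truncmod` means a C-style truncating remainder; the helper names `binds_equal` and `is_aligned` are hypothetical and only mirror the two conditions being compared.

```python
def truncmod(a: int, b: int) -> int:
    """C-style remainder: the result takes the sign of the dividend."""
    r = abs(a) % abs(b)
    return r if a >= 0 else -r


def binds_equal(offset: int, factor: int) -> bool:
    """Condition asserted by binding offset against truncmod(offset, factor)."""
    return offset == truncmod(offset, factor)


def is_aligned(offset: int, factor: int) -> bool:
    """Condition the divisibility check is meant to enforce."""
    return truncmod(offset, factor) == 0


FACTOR = 4
# (offset, equality form, divisibility form)
table = [(o, binds_equal(o, FACTOR), is_aligned(o, FACTOR)) for o in (0, 3, 8)]
```

With a factor of 4, an offset of 8 is aligned but fails the equality form, while an offset of 3 is misaligned but passes it. The two conditions agree only at offset 0 in this sample.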
🤖 Prompt for AI Agents
In src/transform/arg_binder.cc around lines 556 to 558, the BindNullable call
currently uses BindNullable(offset, truncmod(offset, factor), ...), which
effectively checks offset == truncmod(offset, factor) (true whenever |offset| <
factor, regardless of divisibility); instead it should assert that
truncmod(offset, factor) == 0. Change the call so the first argument is
truncmod(offset, factor) and the second argument is 0 (preserving the same
arg_name, nullable flag and is_null), i.e. bind the truncated modulus and
compare it to zero to correctly enforce divisibility by factor.
namespace tvm {
namespace tl {

using namespace tir;
🛠️ Refactor suggestion | 🟠 Major
Avoid using namespace in header files.
Placing using namespace tir; in a header pollutes the namespace for all translation units that include this header. This can cause name collisions and unexpected behavior.
Consider removing this and using qualified names in declarations:
-using namespace tir;
-
// Forward declaration
class MergeIfStmtRewriter;
...
-PrimFunc MergeIfStmtSubstitute(PrimFunc &f);
+tir::PrimFunc MergeIfStmtSubstitute(tir::PrimFunc &f);
-Stmt ApplyMergeIfStmt(Stmt stmt);
+tir::Stmt ApplyMergeIfStmt(tir::Stmt stmt);
Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In src/transform/merge_if_stmt.h around line 14, remove the top-level 'using
namespace tir;' which pollutes global namespaces in all translation units;
instead, qualify all references to tir types/functions in this header with the
'tir::' prefix (or introduce a local namespace alias like 'namespace t = tir;'
within the header and use 't::' qualifiers), or move any convenience 'using'
directives into the corresponding .cpp file or into limited function scopes;
also add any needed forward declarations or includes so the fully-qualified
names resolve.
This commit enhances the matmul kernel by introducing additional tensor parameters and refining the pipeline stages for improved performance. It also updates the argument binding mechanism to include a flag indicating whether buffers are used, enhancing the efficiency of buffer management. Furthermore, the optimization phase in the engine is improved by adding a simplification step, ensuring better performance and clarity in the generated code.
…nding assertions
This commit introduces a new documentation page for host-side tensor checks, detailing the automatic validations performed by TileLang on kernel arguments. It enhances the ArgBinder by adding assertions for non-null pointers when arguments are used, improving error handling. Additionally, the optimization phase in the engine is updated to include a simplification step, ensuring better performance and clarity in the generated code.
Actionable comments posted: 0
♻️ Duplicate comments (2)
src/transform/arg_binder.cc (2)
142-142: Remove commented-out debug code.
The `// ICHECK(false);` is leftover debug code that should be removed.
-  // ICHECK(false);
   return false;
563-564: Incorrect usage of BindNullable for divisibility check.
The call `BindNullable(offset, truncmod(offset, factor), ...)` attempts to assert that `offset == truncmod(offset, factor)`, which holds whenever `|offset| < factor` rather than when `offset % factor == 0`. This is the second binding of `offset` (it was already bound at lines 556-557), so the call becomes a constraint check.
The correct check should be `truncmod(offset, factor) == 0`.
- BindNullable(offset, truncmod(offset, factor), arg_name + ".elem_offset",
-              true, is_null);
+ PrimExpr zero = make_zero(offset.dtype());
+ BinderAddAssert(&analyzer_, zero == truncmod(offset, factor),
+                 arg_name + ".elem_offset", &asserts_, is_null);
Based on past review comments.
🧹 Nitpick comments (4)
src/transform/arg_binder.cc (2)
44-54: Simplify the NULL-guarded assertion wrapping.
When `nullable_guard` is defined, the assertion is wrapped in `SeqStmt({check, Evaluate(0)})`. This extra wrapping appears unnecessary since `check` is already a `Stmt`.
Consider simplifying:
- Stmt check = AssertStmt(scond, StringImm(os.str()), Evaluate(0));
- check = IfThenElse(Not(nullable_guard), check);
- asserts->emplace_back(SeqStmt({check, Evaluate(0)}));
+ Stmt check = IfThenElse(Not(nullable_guard),
+                         AssertStmt(scond, StringImm(os.str()), Evaluate(0)));
+ asserts->emplace_back(check);
This would make the code more concise and avoid the unnecessary `SeqStmt` wrapper.
327-330: Consider simplifying the SeqStmt wrapping pattern.
The pattern `SeqStmt({check, nop})` appears multiple times throughout this function (lines 330, 409, 493, 546, 629). This wrapping seems unnecessary since the check is already a `Stmt`.
Consider simplifying to just:
- Stmt ndim_check = AssertStmt(a_ndim == v_ndim, msg, nop);
- ndim_check = IfThenElse(Not(is_null), ndim_check);
- init_nest_.emplace_back(SeqStmt({ndim_check, nop}));
+ Stmt ndim_check = IfThenElse(Not(is_null),
+                              AssertStmt(a_ndim == v_ndim, msg, nop));
+ init_nest_.emplace_back(ndim_check);
This would make the code more concise and consistent.
docs/compiler_internals/tensor_checks.md (1)
74-114: Clarify nullability rules for symbolic runtime conditions.
The nullability examples are helpful, but example #4 (lines 107-113) states that tensors are non-nullable when the condition is only known at runtime. This is consistent with conservative static analysis, but it might be helpful to explain why this design choice was made.
Consider adding a brief note:
4) Must be non-NULL (runtime condition)
```python
@T.prim_func
def main(A: T.Tensor((M, K), dtype), some_cond: T.bool):
    if some_cond:
        A[0] = 1
```
Since `some_cond` is only known at runtime, static analysis cannot prove `A` is unused; `A` is thus non-nullable.
+
+Rationale: Conservative static analysis treats dynamically-gated accesses as "potentially used" to avoid runtime errors when the condition evaluates to true.
src/transform/make_packed_api.cc (1)
546-546: Document why MergeIfStmtSubstitute is applied here.
`MergeIfStmtSubstitute(func)` is called after `MakePackedAPI` to merge consecutive `if` statements with identical conditions. This makes sense given that NULL-guarded assertions introduce many conditional checks. However, it's not immediately obvious why this happens inside the MakePackedAPI transformation rather than as a separate pass in the pipeline.
Consider adding a comment:
```diff
 func = MakePackedAPI(std::move(func));
+ // Merge consecutive if-statements with identical NULL-guards to simplify the generated code
 func = MergeIfStmtSubstitute(func);
```
Alternatively, could this be applied as a standalone pass in phase.py instead?
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- docs/compiler_internals/tensor_checks.md (1 hunks)
- docs/index.md (1 hunks)
- src/transform/arg_binder.cc (18 hunks)
- src/transform/arg_binder.h (1 hunks)
- src/transform/make_packed_api.cc (4 hunks)
- tilelang/engine/phase.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
- docs/index.md
🧰 Additional context used
🧬 Code graph analysis (2)
tilelang/engine/phase.py (2)
src/transform/simplify.cc (2)
Simplify(530-538)Simplify(530-530)tilelang/transform/simplify.py (1)
Simplify(20-28)
src/transform/make_packed_api.cc (1)
src/transform/merge_if_stmt.cc (4)
op(47-115)op(47-47)MergeIfStmtSubstitute(118-120)MergeIfStmtSubstitute(118-118)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Quick Lint
🔇 Additional comments (12)
tilelang/engine/phase.py (1)
228-228: Clarify the necessity of this additional Simplify pass.
An additional `Simplify()` pass is added immediately after `MakePackedAPI()`. While this may help fold redundant conditions introduced by the NULL-safety transformations, it's not clear whether this pass is essential or optional.
Please clarify:
- Does this Simplify pass address specific patterns introduced by MakePackedAPI (e.g., redundant NULL checks or guards)?
- Have you measured the impact on compile time vs. code quality?
- Could this be combined with the existing Simplify calls at lines 183 and 188, or is the placement after MakePackedAPI critical?
Consider documenting the rationale inline:
 mod = tilelang.transform.MakePackedAPI()(mod)
+# Simplify to fold redundant NULL-guards and conditions introduced by MakePackedAPI
 mod = tilelang.transform.Simplify()(mod)
src/transform/arg_binder.cc (5)
146-168: LGTM: Bind_ correctly maintains non-nullable assertions.
The `Bind_` function correctly continues to emit unconditional assertions without nullable guards, which is appropriate for required (non-nullable) bindings.
277-299: LGTM: is_used parameter correctly enforces NULL-safety.
The `is_used` parameter appropriately distinguishes between buffers that must be non-NULL (used by the function) and those that may be NULL (unused). The logic correctly enforces a non-NULL assertion when `is_used` is true.
584-600: LGTM: Enhanced device type error messages.
The device type check is correctly guarded by `is_null`, and the enhanced error message with the DLPack device code legend (line 594-595) will help users diagnose device mismatch issues more easily.
622-629: LGTM: Data pointer NULL check correctly handles edge cases.
The data pointer check correctly:
- Allows NULL data pointers for size-0 arrays (`alloc_size == 0`)
- Guards the check with `is_null` to skip when the handle itself is NULL
- Provides a clear error message
439-448: I encountered a repository access issue. To verify this review comment about the potential NULL pointer dereference, I'll need additional information from you. Could you please provide:
- The full context of lines 400-450 from `src/transform/arg_binder.cc` (including the sections mentioned: lines 406-409 for dtype check, lines 420-426 for buf_shape initialization, and lines 439-448 for the shape binding)
- The implementation of the `BindNullable` function to understand how it guards nullable operations
- Usage examples of similar patterns in the codebase where nullable buffer operations are performed
Alternatively, if you have access to the repository, you could share the output of:
sed -n '400,450p' src/transform/arg_binder.cc
Once you provide this context, I'll be able to verify whether the NULL pointer dereference concern is valid and whether the proposed fix is necessary.
src/transform/arg_binder.h (1)
106-108: LGTM: Signature change for is_used parameter.
The addition of the `bool is_used` parameter to `BindDLTensor` is consistent with the implementation in `arg_binder.cc` and enables NULL-safe buffer handling.
docs/compiler_internals/tensor_checks.md (2)
1-56: LGTM: Well-structured documentation introduction.
The introduction clearly explains:
- Why host-side checks are performed (ABI stability, low overhead, focused errors)
- How to inspect the generated host source for debugging
- What checks are performed
The documentation structure is logical and easy to follow.
236-381: LGTM: Comprehensive troubleshooting examples.
The minimal reproduction examples for each error type are extremely valuable for users debugging issues. The examples cover all major error scenarios:
- Argument count/type mismatches
- Shape/stride/dtype mismatches
- Device mismatches
- NULL pointer issues
These will significantly reduce support burden.
src/transform/make_packed_api.cc (3)
301-325: LGTM: Buffer-to-parameter mapping enables usage tracking.
The data_var2param and shape_var2params mappings correctly establish the relationship between:
- Buffer data variables and their parameter variables
- Buffer shape/stride variables and their parameter variables
This enables accurate detection of which buffers are actually used by the function body, supporting NULL-safe binding for unused buffers.
467-470: LGTM: is_used flag correctly propagated.
The `used_param_buffers.count(var.get())` check correctly determines whether each buffer is used, passing the appropriate boolean to `BindDLTensor`. This enables NULL-safe handling of unused buffers.
329-366: Based on my analysis of the code and the TVM visitor pattern semantics, I can now provide a rewritten review comment:
The `UsedBufferDetector` correctly handles buffer usage through the recursive visitor pattern.
The implementation uses `StmtExprVisitor`, which recursively traverses the entire expression and statement tree. This means:
- Variables in Let bindings are caught by `VisitExpr_(const VarNode*)` during recursive traversal
- Variables captured in closures or passed to functions are similarly visited as `VarNode` instances
- `BufferStore` and `BufferLoad` operations explicitly extract the underlying `buffer->data` reference
The three detection methods (direct variables, `BufferStore`, `BufferLoad`) are sufficient for the visitor's purpose. However, if `AddressOf` operations exist in the codebase that take buffer addresses without explicit `BufferLoad`/`BufferStore` context, those might not be detected. Verify this specific pattern if it's used in the codebase.
… performance
This commit adds host checks logs to the .gitignore file to prevent unnecessary log files from being tracked. Additionally, it refines the matmul kernel by adjusting pipeline stages, updating tensor parameters, and enhancing argument handling for better performance. The changes also include improved error messages in the argument binding process, ensuring clearer diagnostics for users.
Actionable comments posted: 2
🧹 Nitpick comments (8)
examples/quickstart.py (1)
52-55: Large matrix dimensions may be excessive for a quickstart example.
The matrix dimensions (16384×16384×16384) are very large for a quickstart example. This may cause:
- Long execution times for users trying the example
- High memory usage (~1.5GB+ for the tensors)
Consider using smaller dimensions (e.g., 1024 or 4096) for the example, or add a comment explaining why large sizes are needed.
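The memory estimate can be reproduced with a short calculation. This sketch assumes three dense matrices A (M×K), B (K×N), and C (M×N) stored as 2-byte elements such as float16; the element size is an assumption, since the dtype used in `examples/quickstart.py` is not shown here.

```python
def tensor_bytes(rows: int, cols: int, bytes_per_elem: int) -> int:
    """Size in bytes of a dense rows x cols tensor."""
    return rows * cols * bytes_per_elem


M = N = K = 16384
BYTES_PER_ELEM = 2  # assumption: float16 inputs and output

total_bytes = (
    tensor_bytes(M, K, BYTES_PER_ELEM)    # A
    + tensor_bytes(K, N, BYTES_PER_ELEM)  # B
    + tensor_bytes(M, N, BYTES_PER_ELEM)  # C
)
total_gib = total_bytes / 1024**3

# The same calculation for a smaller quickstart size.
small_mib = 3 * tensor_bytes(1024, 1024, BYTES_PER_ELEM) / 1024**2
```

Under these assumptions each 16384×16384 tensor takes 0.5 GiB, so the three tensors together need 1.5 GiB, while 1024×1024 tensors need 6 MiB in total. A 4-byte dtype would double both figures.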
src/runtime/error_helpers.cc (1)
1-3: Consider adding a license header.
This file is missing the Apache 2.0 license header that appears in other files in this project (e.g., `make_packed_api.cc`). For consistency and compliance, consider adding it.
src/transform/make_packed_api.cc (1)
399-404: Consider using `strlen(suffix)` or a named constant instead of the magic number 7.
The hardcoded `7` for the `"_handle"` suffix length is fragile if the suffix changes.
- const char *suffix = "_handle";
- if (display_name.size() >= 7 &&
-     display_name.compare(display_name.size() - 7, 7, suffix) == 0) {
-   display_name.erase(display_name.size() - 7);
+ constexpr std::string_view suffix = "_handle";
+ if (display_name.size() >= suffix.size() &&
+     display_name.compare(display_name.size() - suffix.size(), suffix.size(), suffix) == 0) {
+   display_name.erase(display_name.size() - suffix.size());
  }
maint/host_checks/06_strides_mismatch.py (1)
11-15: Consider clarifying the shape after transpose.
The transpose `a.t()` changes shape from `(M, K)` to `(K, M)`, which may also trigger shape-related validation in addition to stride checks. If the intent is purely to test strides validation with non-contiguous memory, you could add a comment noting that both shape and stride mismatches may contribute to the error, or use a different technique to create non-contiguous memory while preserving the expected shape.
That said, this is sufficient for demonstrating strides-related failure behavior.
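The interaction between shape and strides under a transpose can be worked out without a tensor library. The following pure-Python sketch assumes row-major element strides; `contiguous_strides` and `transpose_2d` are hypothetical helpers, and the sizes are chosen only for illustration.

```python
def contiguous_strides(shape):
    """Row-major element strides for a dense tensor of the given shape."""
    strides = []
    running = 1
    for dim in reversed(shape):
        strides.append(running)
        running *= dim
    return tuple(reversed(strides))


def transpose_2d(shape, strides):
    """A 2-D transpose view swaps both the shape and the strides."""
    return (shape[1], shape[0]), (strides[1], strides[0])


shape = (1024, 512)                   # (M, K), sizes chosen for illustration
strides = contiguous_strides(shape)   # strides of the original dense tensor
t_shape, t_strides = transpose_2d(shape, strides)

# Strides a dense tensor of the transposed shape would have.
dense_t_strides = contiguous_strides(t_shape)
```

Here the view has shape `(512, 1024)` with strides `(1, 512)`, whereas a dense tensor of that shape would have strides `(1024, 1)`. When M equals K the transposed shape matches the original, so only the strides differ and a shape check cannot mask the stride check.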
maint/host_checks/run_all.py (2)
44-50: Consider distinguishing environment/setup failures from “PASS” repros
Lines 44-50 treat any non-zero return code as `PASS`, which will also classify environment/setup errors (e.g., missing CUDA, import errors) as successful repros, even if the host checks were never exercised. If you expect this script to run in heterogeneous or partially misconfigured environments (e.g., CPU-only machines), consider special-casing well-known environment failures (like the `"CUDA is not available; cannot build CUDA kernel for host-check repros."` raised in `maint/host_checks/common.py`) as `SKIP` or a separate `ENV_FAIL` status so the summary better reflects what was actually tested.
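One way to express that distinction is a small classifier over the subprocess result. This is a sketch under assumptions: the function name `classify`, the `ENV_MARKERS` tuple, and the rule that a zero exit code means `FAIL` (because each repro is expected to raise) are illustrative and not taken from `run_all.py`.

```python
# Substrings that indicate the repro never reached the host checks.
ENV_MARKERS = (
    "CUDA is not available",  # message raised in maint/host_checks/common.py
    "ModuleNotFoundError",
    "ImportError",
)


def classify(returncode: int, stderr: str) -> str:
    """Classify one repro run as PASS, FAIL, or ENV_FAIL."""
    if returncode == 0:
        # Assumption: a repro that exits cleanly did not trigger its host check.
        return "FAIL"
    if any(marker in stderr for marker in ENV_MARKERS):
        return "ENV_FAIL"
    return "PASS"


statuses = [
    classify(1, "RuntimeError: kernel main input A dtype mismatch"),
    classify(1, "RuntimeError: CUDA is not available; cannot build CUDA kernel"),
    classify(0, ""),
]
```

Matching on stderr text is brittle, so a dedicated exit code for environment problems may be a sturdier design if the repro scripts can be changed.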
66-71: Let `main()` return an exit code and only call `sys.exit` in the CLI entrypoint
Having `main()` call `sys.exit` directly makes it harder to re-use or unit-test. A small refactor lets `main()` return an int and keeps the CLI behavior identical:
 def main():
@@
-    # Exit non-zero if any FAIL
-    sys.exit(1 if counts.get("FAIL", 0) else 0)
+    # Return non-zero if any FAIL
+    return 1 if counts.get("FAIL", 0) else 0
@@
 if __name__ == "__main__":
-    main()
+    raise SystemExit(main())
This preserves the current exit codes when run as a script but gives you a pure callable for tests and tooling.
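The benefit for testing can be seen in a reduced sketch. Here `main` takes the list of per-script statuses as a parameter, which is a hypothetical simplification; the real `run_all.py` gathers them by running the repro scripts.

```python
from typing import Iterable


def main(statuses: Iterable[str] = ()) -> int:
    """Summarize statuses and return a process exit code instead of exiting."""
    counts = {}
    for status in statuses:
        counts[status] = counts.get(status, 0) + 1
    # Return non-zero if any FAIL
    return 1 if counts.get("FAIL", 0) else 0


# A test or tool can call main() directly and inspect the result.
all_passed = main(["PASS", "PASS"])
one_failed = main(["PASS", "FAIL", "PASS"])

# The CLI entrypoint would stay a thin wrapper:
#     if __name__ == "__main__":
#         raise SystemExit(main(collected_statuses))
```

Because nothing inside `main` terminates the interpreter, a unit test can assert on the returned code without catching `SystemExit`.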
maint/host_checks/common.py (2)
35-50: Align CUDA-availability guarding between matmul and scalar-check kernels
`build_matmul_kernel` (Line 37) guards CUDA targets with `torch.cuda.is_available()`, but `build_scalar_check_kernel` (Lines 44-50) will still attempt `tilelang.compile(..., target="cuda")` even when CUDA is unavailable. For consistency and clearer failures when running `maint/host_checks/10_scalar_type_mismatch.py`, consider factoring the guard into a shared helper and reusing it:
 import tilelang
 import tilelang.language as T
 import torch

+def _ensure_cuda_available(target: str) -> None:
+    if target.startswith("cuda") and not torch.cuda.is_available():
+        raise RuntimeError(
+            "CUDA is not available; cannot build CUDA kernel for host-check repros."
+        )
+
+
 def make_matmul_prim(M,
@@
 def build_matmul_kernel(M=1024, N=1024, K=1024, target="cuda"):
     """Compile and return a callable kernel that takes (A, B) and returns C."""
-    if target.startswith("cuda") and not torch.cuda.is_available():
-        raise RuntimeError("CUDA is not available; cannot build CUDA kernel for host-check repros.")
+    _ensure_cuda_available(target)
@@
 def build_scalar_check_kernel(target="cuda"):
-
-    @T.prim_func
+    _ensure_cuda_available(target)
+
+    @T.prim_func
     def scalar_check(x: T.int32, flag: T.bool()):
         T.evaluate(0)
This keeps behavior identical for matmul kernels and makes the scalar-check path fail fast with the same, more informative message on non-CUDA hosts.
46-49: Silence Ruff ARG001 for intentionally unused `scalar_check` parameters
Static analysis (Ruff ARG001) correctly notes that `x` and `flag` in `scalar_check` are unused, but they are required by the prim_func signature so the host-side type checker can fire before running the body. To document that this is intentional and keep linters quiet, you can add a `noqa` on the def line:
-    @T.prim_func
-    def scalar_check(x: T.int32, flag: T.bool()):
+    @T.prim_func
+    def scalar_check(x: T.int32, flag: T.bool()):  # noqa: ARG001
         T.evaluate(0)
This avoids changing kernel behavior while making the intent explicit. Based on static analysis hints.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (18)
- .gitignore (1 hunks)
- examples/quickstart.py (3 hunks)
- maint/host_checks/01_num_args_mismatch.py (1 hunks)
- maint/host_checks/02_pointer_type_error.py (1 hunks)
- maint/host_checks/03_ndim_mismatch.py (1 hunks)
- maint/host_checks/04_dtype_mismatch.py (1 hunks)
- maint/host_checks/05_shape_mismatch.py (1 hunks)
- maint/host_checks/06_strides_mismatch.py (1 hunks)
- maint/host_checks/07_device_type_mismatch.py (1 hunks)
- maint/host_checks/08_device_id_mismatch.py (1 hunks)
- maint/host_checks/09_null_data_pointer.py (1 hunks)
- maint/host_checks/10_scalar_type_mismatch.py (1 hunks)
- maint/host_checks/README.md (1 hunks)
- maint/host_checks/common.py (1 hunks)
- maint/host_checks/run_all.py (1 hunks)
- src/runtime/error_helpers.cc (1 hunks)
- src/transform/make_packed_api.cc (8 hunks)
- tilelang/jit/adapter/tvm_ffi.py (0 hunks)
💤 Files with no reviewable changes (1)
- tilelang/jit/adapter/tvm_ffi.py
✅ Files skipped from review due to trivial changes (2)
- .gitignore
- maint/host_checks/README.md
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.
Applied to files:
examples/quickstart.py
🧬 Code graph analysis (12)
maint/host_checks/10_scalar_type_mismatch.py (2)
- maint/host_checks/common.py (2): build_scalar_check_kernel(44-50), main(16-30)
- maint/host_checks/run_all.py (1): main(6-67)
maint/host_checks/06_strides_mismatch.py (3)
- maint/host_checks/common.py (2): build_matmul_kernel(35-41), main(16-30)
- maint/host_checks/run_all.py (1): main(6-67)
- tilelang/language/ast/ir.py (1): target(1682-1713)
maint/host_checks/04_dtype_mismatch.py (2)
- maint/host_checks/common.py (2): build_matmul_kernel(35-41), main(16-30)
- tilelang/jit/adapter/tvm_ffi.py (1): get_host_source(282-286)
examples/quickstart.py (3)
- tilelang/transform/pass_config.py (1): PassConfigKey(6-144)
- tilelang/env.py (1): disable_cache(275-276)
- tilelang/jit/adapter/base.py (1): get_kernel_source(93-97)
maint/host_checks/03_ndim_mismatch.py (5)
- maint/host_checks/common.py (2): build_matmul_kernel(35-41), main(16-30)
- maint/host_checks/run_all.py (1): main(6-67)
- maint/host_checks/01_num_args_mismatch.py (1): main(10-17)
- maint/host_checks/05_shape_mismatch.py (1): main(7-15)
- maint/host_checks/02_pointer_type_error.py (1): main(10-18)
src/transform/make_packed_api.cc (1)
- src/transform/merge_if_stmt.cc (4): op(47-115), op(47-47), MergeIfStmtSubstitute(118-120), MergeIfStmtSubstitute(118-118)
maint/host_checks/01_num_args_mismatch.py (2)
- maint/host_checks/common.py (2): build_matmul_kernel(35-41), main(16-30)
- maint/host_checks/run_all.py (1): main(6-67)
maint/host_checks/09_null_data_pointer.py (4)
- maint/host_checks/common.py (2): build_matmul_kernel(35-41), main(16-30)
- maint/host_checks/run_all.py (1): main(6-67)
- maint/host_checks/01_num_args_mismatch.py (1): main(10-17)
- maint/host_checks/02_pointer_type_error.py (1): main(10-18)
maint/host_checks/07_device_type_mismatch.py (2)
- maint/host_checks/common.py (2): build_matmul_kernel(35-41), main(16-30)
- maint/host_checks/run_all.py (1): main(6-67)
maint/host_checks/02_pointer_type_error.py (2)
- maint/host_checks/common.py (2): build_matmul_kernel(35-41), main(16-30)
- maint/host_checks/run_all.py (1): main(6-67)
maint/host_checks/common.py (15)
- maint/host_checks/01_num_args_mismatch.py (1): main(10-17)
- maint/host_checks/03_ndim_mismatch.py (1): main(7-15)
- maint/host_checks/04_dtype_mismatch.py (1): main(7-15)
- maint/host_checks/05_shape_mismatch.py (1): main(7-15)
- maint/host_checks/06_strides_mismatch.py (1): main(7-15)
- maint/host_checks/07_device_type_mismatch.py (1): main(7-14)
- maint/host_checks/08_device_id_mismatch.py (1): main(7-21)
- maint/host_checks/10_scalar_type_mismatch.py (1): main(6-11)
- maint/host_checks/02_pointer_type_error.py (1): main(10-18)
- maint/host_checks/09_null_data_pointer.py (1): main(14-21)
- tilelang/language/kernel.py (1): threads(214-218)
- tilelang/language/allocate.py (2): alloc_shared(28-43), alloc_fragment(60-71)
- tilelang/language/loop.py (1): Pipelined(56-93)
- tilelang/language/ast/ir.py (2): target(1682-1713), evaluate(1319-1331)
- tilelang/jit/kernel.py (1): out_idx(598-599)
maint/host_checks/08_device_id_mismatch.py (2)
- maint/host_checks/common.py (2): build_matmul_kernel(35-41), main(16-30)
- maint/host_checks/run_all.py (1): main(6-67)
🪛 Ruff (0.14.6)
maint/host_checks/run_all.py
33-33: subprocess call: check for execution of untrusted input
(S603)
maint/host_checks/common.py
38-38: Avoid specifying long messages outside the exception class
(TRY003)
47-47: Unused function argument: x
(ARG001)
47-47: Unused function argument: flag
(ARG001)
maint/host_checks/08_device_id_mismatch.py
9-9: Avoid specifying long messages outside the exception class
(TRY003)
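For the S603 finding on `run_all.py` above, a common mitigation is to pass a fixed argument list and avoid `shell=True`, so no input is interpreted by a shell. The following is a minimal sketch of that idea only; `run_script` is a hypothetical helper, not the code in `run_all.py`.

```python
import subprocess
import sys


def run_script(script_args, timeout=60):
    """Run a child Python interpreter with the given arguments and capture output.

    Hypothetical helper for illustration; not taken from run_all.py.
    """
    # A fixed argument list (no shell=True) means nothing is parsed by a shell.
    cmd = [sys.executable, *script_args]
    proc = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)
    return proc.returncode, proc.stdout, proc.stderr


code, out, err = run_script(["-c", "print('ok')"])
print(code, out.strip())
```

Ruff may still report S603 on such a call, since the rule flags subprocess usage broadly; when the arguments are known to be trusted, the finding is typically acknowledged rather than "fixed".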
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
🔇 Additional comments (22)
examples/quickstart.py (3)
8-12: LGTM on the pass configuration setup.
The pass_configs dictionary correctly uses `PassConfigKey` enum values to disable TMA lowering and warp specialization.
34-34: Setting `num_stages=0` disables pipelining.
For a quickstart example intended to showcase TileLang performance, `num_stages=0` disables the pipelining optimization. If this is intentional for testing the assertion handling changes, consider adding a comment explaining why, or restore a non-zero value for the example.
62-62: Good additions for demonstrating kernel introspection and profiling.
Printing the host source (line 62) and kernel source (lines 92-93) provides useful visibility into the generated code. The latency profiling (lines 96-100) is a good demonstration of TileLang's profiling capabilities.
Also applies to: 92-100
src/runtime/error_helpers.cc (3)
16-32: LGTM – clean error helper implementation.
The function correctly constructs `DataType` objects from the packed arguments and produces a clear, actionable error message with kernel name, buffer name, and expected vs. actual dtypes. The return value of `-1` signals failure to the caller as intended.
36-49: LGTM – consistent variant without names.
This follows the same pattern as `DTypeMismatch` and provides a lightweight alternative when buffer/kernel names are not available.
54-60: LGTM – FFI registration looks correct.
The static init block properly registers both error helpers with descriptive names.
src/transform/make_packed_api.cc (5)
42-42: LGTM – new include for merge_if_stmt transformation.
The include is appropriately placed with other local includes.
301-325: LGTM – well-structured usage tracking infrastructure.
The reverse mapping from buffer data vars to parameters, along with the shape/stride variable tracking, enables precise detection of buffer usage. The `PostOrderVisit` approach correctly captures all variable references within buffer metadata.
329-366: LGTM – UsedBufferDetector correctly tracks buffer usage.
The visitor properly handles:
- Direct var references via `VisitExpr_(VarNode*)`
- Buffer stores and loads via dedicated visit methods
- Transitive usage through shape/stride variables
One minor note: calling `StmtExprVisitor::VisitExpr_` and `StmtExprVisitor::VisitStmt_` after marking the buffer ensures child nodes are also visited.
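The usage tracking described in this comment can be illustrated with a toy model. This is a minimal Python sketch, not the C++ visitor in `make_packed_api.cc`: the tuple-based IR and the function `detect_used_buffers` are assumptions made for illustration, and the two dictionaries only borrow the names `data_var2param` and `shape_var2params` from the walkthrough.

```python
# Toy IR made of nested tuples (hypothetical; the real pass walks TIR nodes in C++):
#   ("var", name)
#   ("load", buffer_data_var, index_expr)
#   ("store", buffer_data_var, index_expr, value_expr)
#   ("seq", stmt, ...)

def detect_used_buffers(body, data_var2param, shape_var2params):
    """Return the set of parameter names whose buffers the body touches."""
    used = set()

    def visit(node):
        if not isinstance(node, tuple):
            return  # plain strings (names) carry no children
        kind = node[0]
        if kind == "var":
            name = node[1]
            # Direct reference to a buffer's data pointer.
            if name in data_var2param:
                used.add(data_var2param[name])
            # Transitive use: a shape/stride variable marks its buffers as used.
            for param in shape_var2params.get(name, ()):
                used.add(param)
        elif kind in ("load", "store"):
            if node[1] in data_var2param:
                used.add(data_var2param[node[1]])
        # Visit children after marking, so nested references are also seen.
        for child in node[1:]:
            visit(child)

    visit(body)
    return used


data_var2param = {"A_data": "A", "B_data": "B", "Bias_data": "Bias"}
shape_var2params = {"m": ["A"], "n": ["B"]}
body = ("seq",
        ("store", "A_data", ("var", "i"), ("load", "B_data", ("var", "n"))))
used = detect_used_buffers(body, data_var2param, shape_var2params)
print(sorted(used))
```

In this toy body, `Bias` is never referenced, so it is the kind of parameter a later binding step could treat as unused.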
482-488: LGTM – BindDLTensor now receives display name and is_used flag.
This change enables nullable handling for unused buffers, improving assertion precision. The display name uses the buffer's data var name for clearer diagnostics.
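One way to picture the nullable handling is a check that runs only when the handle is non-NULL, with NULL rejected only for buffers the kernel actually uses. The sketch below is an assumption-laden Python model of that idea: `bind_tensor_checks`, the dict-based tensor, and the error wording are all hypothetical and do not reproduce the `ArgBinder` code.

```python
def bind_tensor_checks(arg_name, tensor, expected_ndim, is_used):
    """Return a list of error strings for one argument (empty if it passes).

    `tensor` is None (standing in for a NULL handle) or a dict with an "ndim" key.
    Hypothetical model; messages are illustrative only.
    """
    errors = []
    if tensor is None:
        # Assumption: only a buffer the kernel uses must be non-NULL.
        if is_used:
            errors.append(f"{arg_name} is expected to be non-NULL")
        return errors
    # Analogous to wrapping the assertion in `if (handle != NULL) { ... }`.
    if tensor["ndim"] != expected_ndim:
        errors.append(
            f"{arg_name}.ndim is expected to equal {expected_ndim}, got {tensor['ndim']}"
        )
    return errors


print(bind_tensor_checks("Bias", None, 1, is_used=False))
print(bind_tensor_checks("A", None, 2, is_used=True))
print(bind_tensor_checks("A", {"ndim": 3}, 2, is_used=True))
```

The design point is that metadata checks such as ndim are skipped for a NULL handle instead of dereferencing it.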
563-564: I attempted to verify the review comment but encountered a repository access issue. The codebase cannot be cloned at this time, which prevents me from:
- Checking the header files included in `src/transform/make_packed_api.cc`
- Locating the declaration of `MergeIfStmtSubstitute`
- Verifying the function signature matches the usage at lines 563-564
- Confirming the function is properly declared before use
Unable to verify the `MergeIfStmtSubstitute` function declaration due to repository access constraints.
The review comment's verification request cannot be completed without access to the codebase. Manual verification is required to confirm that `MergeIfStmtSubstitute` is correctly declared in the included headers and that the function signature is compatible with its usage at lines 563-564.
maint/host_checks/07_device_type_mismatch.py (1)
1-18: LGTM – clear repro script for device-type mismatch.
The script correctly demonstrates the device-type mismatch scenario by passing CPU tensors to a CUDA kernel. CUDA availability is handled upstream in `build_matmul_kernel` (per `common.py` lines 36-37), which raises `RuntimeError` if CUDA is unavailable.
maint/host_checks/10_scalar_type_mismatch.py (1)
9-11: Only the first mismatch will be reproduced.
The first call `fn(1.0, True)` will likely raise an exception, preventing the second call `fn(1, 2.5)` from executing. If both scenarios need to be tested independently, consider separating them into distinct scripts or wrapping each in a try-except block.
If only demonstrating one mismatch case is sufficient for this repro script, this is fine as-is.
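The try-except option mentioned above could look roughly like this. It is a minimal sketch: `run_scenarios` is a hypothetical helper, and `fake_kernel` is a stand-in for the compiled kernel `fn` so that the example runs without TileLang or CUDA.

```python
def run_scenarios(fn, scenarios):
    """Call fn with each argument tuple and record the outcome of every call."""
    results = []
    for args in scenarios:
        try:
            fn(*args)
            results.append((args, None))
        except Exception as exc:  # repro script: report the error, keep going
            results.append((args, f"{type(exc).__name__}: {exc}"))
    return results


def fake_kernel(x, flag):
    # Hypothetical stand-in for the real kernel: expects (int, bool).
    if not isinstance(x, int) or isinstance(x, bool):
        raise TypeError("x must be int")
    if not isinstance(flag, bool):
        raise TypeError("flag must be bool")


results = run_scenarios(fake_kernel, [(1.0, True), (1, 2.5)])
for args, error in results:
    print(args, "->", error)
```

With this structure the second scenario still runs after the first one raises, so both mismatches show up in a single invocation.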
maint/host_checks/04_dtype_mismatch.py (1)
1-19: LGTM – clear repro script for dtype mismatch.
The script correctly demonstrates a dtype mismatch by providing a `float32` tensor `a` when `float16` is expected. The `print(fn.get_host_source())` call is helpful for debugging the generated host code. Note that this script requires CUDA to be available (handled by `build_matmul_kernel` in `common.py`).
maint/host_checks/02_pointer_type_error.py (1)
1-22: LGTM!
The repro script is well-structured and correctly demonstrates passing an incorrect type (int instead of tensor) to trigger the expected pointer-type assertion error. The docstring clearly describes the expected behavior.
maint/host_checks/05_shape_mismatch.py (1)
1-19: LGTM!
The script correctly reproduces a shape mismatch scenario by constructing tensor `a` with dimension `K+1` instead of the expected `K`. The implementation follows the established pattern for host-check scripts.
maint/host_checks/09_null_data_pointer.py (1)
1-25: LGTM!
The script appropriately documents the distinction between passing Python `None` versus a true DLTensor with NULL data, and correctly reproduces the intended class of pointer validation errors. The detailed docstring adds valuable context.
maint/host_checks/01_num_args_mismatch.py (1)
1-21: LGTM!
The script correctly reproduces an argument count mismatch by omitting the second input tensor. The docstring and inline comments clearly explain the expected behavior and that the error occurs at the adapter level before host entry.
maint/host_checks/03_ndim_mismatch.py (1)
1-19: LGTM!
The script correctly reproduces an ndim mismatch by constructing tensor `a` with shape `(M, K, 1)` (rank 3) instead of the expected `(M, K)` (rank 2). Clear and follows the established pattern.
maint/host_checks/08_device_id_mismatch.py (2)
7-12: Good handling of multi-GPU requirement.
The script properly checks for CUDA availability and device count before proceeding, using the `[SKIP]` pattern that `run_all.py` recognizes. This ensures graceful handling on single-GPU systems.
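A skip guard of the kind described here might be structured as follows. This is a sketch under assumptions: `check_requirements` is a hypothetical helper and the message text after the `[SKIP]` prefix is invented. In a real script the two inputs would come from `torch.cuda.is_available()` and `torch.cuda.device_count()`; they are plain parameters here so the example runs without a GPU.

```python
def check_requirements(cuda_available, device_count, needed=2):
    """Return a "[SKIP] ..." message if the scenario cannot run, else None."""
    if not cuda_available:
        return "[SKIP] CUDA is not available"
    if device_count < needed:
        return f"[SKIP] need at least {needed} CUDA devices, found {device_count}"
    return None


def main(cuda_available, device_count):
    reason = check_requirements(cuda_available, device_count)
    if reason is not None:
        # Print the marker and exit cleanly so a runner can count this as skipped.
        print(reason)
        return 0
    print("running device-id mismatch scenario")
    return 0


main(cuda_available=True, device_count=1)
```

Keeping the decision in a pure function makes the skip logic easy to exercise on a machine with no GPU at all.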
17-21: LGTM!
The device ID mismatch is correctly reproduced by placing tensor `a` on `cuda:0` and tensor `b` on `cuda:1`. This will trigger the expected host-side device validation error.
maint/host_checks/common.py (1)
6-32: Matmul prim and tiling pattern look consistent (LGTM)
The `make_matmul_prim` construction (tiled `A_shared`/`B_shared`, `C_local` fragment, `T.Pipelined` over `ko`, and `T.gemm` followed by a final `T.copy` into `C`) is coherent and matches the typical tilelang GEMM pattern. No functional issues stand out here.
This commit refactors the tensor_null_test function by adding a with_bias parameter and removing the ptr_null_test function, which was previously unused. The run_test function is updated to reflect these changes, streamlining the testing process for tensor operations.
Actionable comments posted: 0
🧹 Nitpick comments (1)
testing/python/jit/test_tilelang_jit_nullptr.py (1)
23-25: Align B tensor shape annotation with its runtime layout
`B` is declared as `T.Tensor((K, N), dtype)` (Line 23) but instantiated in PyTorch with shape `(N, K)` (Line 53) and tiled as `(block_N, block_K)`, i.e. `(N, K)`. This only works silently here because `N == K`; for general shapes it becomes misleading and could confuse host-side checks.
Recommend updating the annotation to match actual usage:
- B: T.Tensor((K, N), dtype),
+ B: T.Tensor((N, K), dtype),
so the prim_func signature, tiling, and test inputs stay consistent.
Also applies to: 53-57
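To see why the mismatch stays hidden only when `N == K`, consider a plain shape comparison like the one a host-side check would perform. The sketch below is illustrative: `check_shape` and `b_annotation_matches` are hypothetical helpers, and the sizes 512 and 256 are arbitrary example values, not the ones used in the test.

```python
def check_shape(name, declared, actual):
    """Raise ValueError if the runtime shape differs from the declared one."""
    if tuple(declared) != tuple(actual):
        raise ValueError(
            f"{name}: declared shape {tuple(declared)}, got {tuple(actual)}"
        )


def b_annotation_matches(N, K):
    declared = (K, N)  # annotation as currently written
    actual = (N, K)    # layout of the tensor actually passed in
    try:
        check_shape("B", declared, actual)
        return True
    except ValueError:
        return False


print(b_annotation_matches(512, 512))  # square case: shapes coincide
print(b_annotation_matches(512, 256))  # non-square case: shapes differ
```

With the annotation changed to `(N, K)`, the declared and actual shapes agree for every choice of `N` and `K`.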
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
testing/python/jit/test_tilelang_jit_nullptr.py(2 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-14T07:56:11.098Z
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.
Applied to files:
testing/python/jit/test_tilelang_jit_nullptr.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
- GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (1)
testing/python/jit/test_tilelang_jit_nullptr.py (1)
10-18: JIT wrapper signature and `with_bias` flag are consistent with the kernel body
Capturing `with_bias` in the closure and specializing via `tensor_null_test(..., with_bias=False)` matches the intended nullptr repro; the parameter ordering and defaults look good.
This commit refines the assertion message generation in CodeGenCHost by optimizing the handling of equality checks and reducing buffer size for error messages. Additionally, it enhances the ArgBinder by introducing a nullable guard mechanism for assertions, allowing for more precise error handling when binding arguments. The changes improve the clarity and efficiency of assertion handling across the codebase.
Summary by CodeRabbit
Bug Fixes
Performance
Improvements
Documentation
Tools