[Bugfix] Resolve mixed stride dtype issue (inconsistent int32/int64 values) #1119
Walkthrough
Adds a new ArgBinder utility implementing argument-to-value binding and validations (arrays, Buffer, DLTensor), updates vectorization checks to use dtype-consistent size constants, and adjusts an include path and clang-tidy header filter.
Sequence Diagram(s)
sequenceDiagram
autonumber
participant CallSite as Call Site
participant ArgBinder as ArgBinder
participant Analyzer as arith::Analyzer
participant AST as AST Builder
Note over CallSite,ArgBinder: Binding request (arg, value, name)
CallSite->>ArgBinder: Bind / BindArray / BindBuffer / BindDLTensor
ArgBinder->>Analyzer: Simplify conditions / compute simplified exprs
Analyzer-->>ArgBinder: simplified exprs
opt validation checks
ArgBinder->>ArgBinder: generate assertions (BinderAddAssert)
end
ArgBinder->>AST: emit LetStmt / DeclBuffer / AssertStmt / IfThenElse
AST-->>ArgBinder: nested init stmts
ArgBinder-->>CallSite: defs, asserts, init_nest
Estimated code review effort
🎯 4 (Complex) | ⏱️ ~45 minutes
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 3
🧹 Nitpick comments (3)
src/transform/arg_binder.h (2)
88-97: Docs: clarify fuzzy_match semantics.
Implementation allows value to have extra leading 1-dims compared to arg (value.rank >= arg.rank). Update comment accordingly.
- * \param fuzzy_match If enabled, we allow value's dimension to be smaller
- *  than arg, as long as arg's higher dimensions are of 1.
+ * \param fuzzy_match If enabled, allow value to have extra leading dimensions of size 1
+ *  (i.e., value.rank >= arg.rank, with value.shape[0:diff] == 1).
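To make the proposed wording concrete, here is a minimal sketch of the fuzzy-match rule over concrete integer shapes. `FuzzyShapeMatch` is a hypothetical helper written for illustration; the actual ArgBinder works on symbolic expressions and emits assertions rather than returning a boolean.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper, not part of ArgBinder: returns true when `value_shape`
// matches `arg_shape` after skipping extra leading dimensions of size 1.
bool FuzzyShapeMatch(const std::vector<int64_t>& arg_shape,
                     const std::vector<int64_t>& value_shape) {
  // value.rank >= arg.rank is required; the rank may grow but not shrink.
  if (value_shape.size() < arg_shape.size()) return false;
  const std::size_t diff = value_shape.size() - arg_shape.size();
  // Every extra leading dimension of value must be exactly 1.
  for (std::size_t i = 0; i < diff; ++i) {
    if (value_shape[i] != 1) return false;
  }
  // The remaining dimensions must agree one by one.
  for (std::size_t i = 0; i < arg_shape.size(); ++i) {
    if (arg_shape[i] != value_shape[diff + i]) return false;
  }
  return true;
}
```

Under this reading, an arg of shape (4, 8) accepts a value of shape (1, 1, 4, 8) but rejects (2, 4, 8).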
112-126: Docs: minor typos.
Fix “statemtn” -> “statement”, “Intializing” -> “Initializing”.
src/transform/arg_binder.cc (1)
40-51: Error text polish (optional).
"Bind have an unmet assertion" → "Binding has an unmet assertion" for clearer logs. Low priority.
📒 Files selected for processing (4)
- src/transform/arg_binder.cc (1 hunks)
- src/transform/arg_binder.h (1 hunks)
- src/transform/loop_vectorize.cc (1 hunks)
- src/transform/make_packed_api.cc (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/transform/arg_binder.cc (2)
tilelang/language/ast/ir.py (3)
- LetStmt (880-908)
- handle (1467-1497)
- decl_buffer (1137-1205)
tilelang/language/tir/op.py (3)
- truncmod (3047-3070)
- isnullptr (2649-2665)
- if_then_else (2907-2937)
src/transform/arg_binder.h (1)
src/transform/arg_binder.cc (10)
- Bind (78-81)
- Bind (78-79)
- BindArray (83-93)
- BindArray (83-85)
- BindBuffer (95-156)
- BindBuffer (95-96)
- BindDLTensor (163-373)
- BindDLTensor (163-165)
- Bind_ (54-76)
- Bind_ (54-55)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (1)
src/transform/make_packed_api.cc (1)
39-39: Include path consistency verified across the codebase.
The verification confirms that all references to arg_binder.h use the consistent local include path (#include "arg_binder.h"). Both arg_binder.cc and make_packed_api.cc follow the same pattern, and no stale old-style includes remain in the repository.
 * \brief Helper utility to generate match and bind of arguments.
 *
 * \note There is many places in TVM IR where we need argument bindings.
 *
 *  Consider a function f(tA(shape=var(n)), tB(shape=3), tC(shape=(n+2)).
 *  Here n is a undefined variable that is decided by the outside, tB imposes
 *  a constraint such that it can only take tensor with shape 3, tC imposes
 *  another constraint that it's shape must equals n + 2.
 *  So if we call it with f(bufferA, bufferB, bufferC), we need to generate
 *  the following binding sequence:
 *  - define n = bufferA.shape[0]
 *  - assert bufferB.shape[0] == 3
 *  - assert bufferB.shape[1] == n + 3
 *
 *  In general, this is a constraint solving problem. We have simplified
 *  assumption over the binding declaration, such that we require the variable
 *  occurred in constraint must be declared in argument list. So it is illegal to
 *  have signature f(tA(shape=(n+3))) without any argument variable corresponds
 *  to n, even though it is already enough to derive n from the input argument.
 */
Docs: fix example typos and mismatches.
Several doc issues can mislead:
- "binded" -> "bound".
- Example says tC(shape=(n+2)) but later mentions "== n + 3" and uses bufferB. Should be bufferC and n + 2.
Proposed patch:
- * \brief Helper utility to generate match and bind of arguments.
+ * \brief Helper utility to generate match and bind of arguments.
...
- * So if we call it with f(bufferA, bufferB, bufferC), we need to generate
+ * So if we call it with f(bufferA, bufferB, bufferC), we need to generate
* the following binding sequence:
* - define n = bufferA.shape[0]
* - assert bufferB.shape[0] == 3
- * - assert bufferB.shape[1] == n + 3
+ * - assert bufferC.shape[0] == n + 2
🤖 Prompt for AI Agents
In src/transform/arg_binder.h around lines 41 to 60, fix the doc typos and
mismatches: change "binded" to "bound", correct the example so tC(shape=(n+2))
corresponds to bufferC and its assertion reads "assert bufferC.shape[0] == n +
2" (not bufferB and not n + 3), and ensure the text consistently refers to
bufferC where appropriate.
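The define-then-assert sequence from the corrected doc comment can be sketched over concrete integers. `Dim` and `ToyBinder` are hypothetical names invented for this illustration; the real ArgBinder emits TIR statements (LetStmt, AssertStmt) instead of evaluating values directly.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <string>

// Hypothetical, simplified stand-in for a symbolic dimension: `var + offset`,
// or the constant `offset` when `var` is empty.
struct Dim {
  std::string var;
  int64_t offset;
};

// Toy binder over concrete integers (illustration only).
class ToyBinder {
 public:
  // Returns true if `value` is consistent with `dim`; defines `dim.var` the
  // first time a bare variable is seen.
  bool Bind(const Dim& dim, int64_t value) {
    if (dim.var.empty()) return value == dim.offset;  // assert value == constant
    auto it = defs_.find(dim.var);
    if (it != defs_.end()) return value == it->second + dim.offset;  // assert
    if (dim.offset != 0) return false;  // n + k seen before n is defined: rejected
    defs_[dim.var] = value;  // define n = value
    return true;
  }
  const std::map<std::string, int64_t>& defs() const { return defs_; }

 private:
  std::map<std::string, int64_t> defs_;
};
```

Binding `{"n", 0}` to 5, then `{"", 3}` to 3, then `{"n", 2}` to 7 mirrors "define n = bufferA.shape[0]", "assert bufferB.shape[0] == 3", and "assert bufferC.shape[0] == n + 2". Rejecting `n + k` before `n` is defined reflects the simplifying assumption stated in the doc comment.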
  PrimExpr target_size_for_iter =
      make_const(iter_var_size.dtype(), target_vectorized_size);
  PrimExpr target_size_for_expr =
      make_const(expr.dtype(), target_vectorized_size);
  PrimExpr target_size_for_var =
      make_const(var.dtype(), target_vectorized_size);
  PrimExpr zero = make_const(var.dtype(), 0);
Finish the dtype cleanup: use typed zero and cast Range extent.
Great move to make per-dtype target constants. Two follow-ups to avoid int32/int64 mismatches:
- Compare against typed zero instead of raw 0.
- Cast the Range extent for v1 to var.dtype() to avoid mixed dtypes (var vs iter_var_size).
Apply this diff:
- PrimExpr target_size_for_iter =
- make_const(iter_var_size.dtype(), target_vectorized_size);
- PrimExpr target_size_for_expr =
- make_const(expr.dtype(), target_vectorized_size);
- PrimExpr target_size_for_var =
- make_const(var.dtype(), target_vectorized_size);
- PrimExpr zero = make_const(var.dtype(), 0);
+ PrimExpr target_size_for_iter = make_const(iter_var_size.dtype(), target_vectorized_size);
+ PrimExpr target_size_for_expr = make_const(expr.dtype(), target_vectorized_size);
+ PrimExpr target_size_for_var = make_const(var.dtype(), target_vectorized_size);
+ PrimExpr zero_var = make_const(var.dtype(), 0);
+ PrimExpr zero_iter = make_const(iter_var_size.dtype(), 0);
+ PrimExpr zero_expr = make_const(expr.dtype(), 0);
- if (!analyzer->CanProveEqual(FloorMod(iter_var_size, target_size_for_iter),
- 0))
+ if (!analyzer->CanProveEqual(FloorMod(iter_var_size, target_size_for_iter), zero_iter))
return false;
- if (!analyzer->CanProveEqual(
- FloorMod(Substitute(expr, {{var, zero}}), target_size_for_expr), 0)) {
+ if (!analyzer->CanProveEqual(
+ FloorMod(Substitute(expr, {{var, zero_var}}), target_size_for_expr), zero_expr)) {
return false;
}
- Var v0("v0", var.dtype()), v1("v1", var.dtype());
- analyzer->Bind(v0, Range(zero, target_size_for_var));
- analyzer->Bind(v1, Range(zero, analyzer->Simplify(FloorDiv(
- iter_var_size, target_size_for_iter))));
+ Var v0("v0", var.dtype()), v1("v1", var.dtype());
+ analyzer->Bind(v0, Range(zero_var, target_size_for_var));
+ PrimExpr v1_extent = analyzer->Simplify(FloorDiv(iter_var_size, target_size_for_iter));
+ analyzer->Bind(v1, Range(zero_var, cast(var.dtype(), v1_extent)));
- PrimExpr expr_transformed = analyzer->Simplify(
- Substitute(expr, {{var, v0 + v1 * target_size_for_var}}));
+ PrimExpr expr_transformed =
+ analyzer->Simplify(Substitute(expr, {{var, v0 + v1 * target_size_for_var}}));
Also applies to: 273-276, 279-291
🤖 Prompt for AI Agents
In src/transform/loop_vectorize.cc around lines 265-272 (and similarly for
273-276 and 279-291), the code uses an untyped literal 0 and mixed dtypes when
constructing target sizes and comparing Range extents; change the zero to a
typed zero created with make_const(var.dtype(), 0) (already declared as zero)
and cast the Range extent (v1) to var.dtype() when used (e.g., wrap v1 with
Cast(var.dtype(), v1) or make_const with var.dtype()) so all comparisons and
size constants use the same dtype and avoid int32/int64 mismatches.
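The idea behind per-dtype constants can be illustrated with a miniature typed integer. `TypedInt`, `MakeConst`, and `StrictFloorMod` are hypothetical names for this sketch and only mimic the spirit of TVM's `make_const` and `FloorMod`; the point is that each constant takes the width of the operand it is paired with, so no operation ever sees mixed int32/int64 inputs.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>

// Hypothetical miniature of a typed integer constant; only the bit width is tracked.
struct TypedInt {
  int bits;  // 32 or 64
  int64_t value;
};

// Builds a constant with an explicit width, in the spirit of make_const(dtype, value).
TypedInt MakeConst(int bits, int64_t value) { return TypedInt{bits, value}; }

// A strict floor-mod that refuses mixed widths instead of silently promoting.
// The divisor is assumed to be non-zero.
TypedInt StrictFloorMod(const TypedInt& a, const TypedInt& b) {
  if (a.bits != b.bits) throw std::invalid_argument("mixed int32/int64 operands");
  int64_t r = a.value % b.value;
  // C++ % truncates toward zero; shift into floor semantics when signs differ.
  if (r != 0 && ((r < 0) != (b.value < 0))) r += b.value;
  return TypedInt{a.bits, r};
}
```

Deriving the vector-size constant from the extent's own width, as in `MakeConst(extent.bits, 4)`, keeps the check well typed whichever width the extent happens to have.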
Actionable comments posted: 0
♻️ Duplicate comments (1)
src/transform/arg_binder.cc (1)
221-227: Fix inconsistent Var usage and type annotation for shape handle.
The past review comment correctly identifies this issue. The code creates v_shape on line 221 and maps it in def_handle_dtype_ on line 222, but the actual buffer uses buf_shape->data (line 224). This inconsistency means downstream code expecting buf_shape->data in the map will not find it.
Additionally, line 222 uses make_const(tvm_shape_type, 0) while strides (line 242) and data (line 367) use tir::TypeAnnotation(...), creating another inconsistency.
Apply this diff to fix both issues:
- Var v_shape(shape_handle_name(), DataType::Handle());
- def_handle_dtype_.Set(v_shape, make_const(tvm_shape_type, 0));
+ def_handle_dtype_.Set(buf_shape->data, tir::TypeAnnotation(tvm_shape_type));
Based on past review comments.
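The lookup failure described above can be reproduced in miniature, assuming the table is keyed by object identity rather than by name. `ToyVar` and `LookupAnnotation` are hypothetical stand-ins, not TVM types.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for a TIR variable; identity (the pointer), not the
// name, serves as the map key.
struct ToyVar {
  std::string name;
};
using VarPtr = std::shared_ptr<ToyVar>;
using AnnotationTable = std::unordered_map<const ToyVar*, std::string>;

// Looks up the annotation registered for `var`; returns "" when none was recorded.
std::string LookupAnnotation(const AnnotationTable& table, const VarPtr& var) {
  auto it = table.find(var.get());
  return it == table.end() ? std::string() : it->second;
}
```

Registering an annotation under one variable and querying with a second variable of the same name finds nothing, which is why the suggestion registers `buf_shape->data` itself.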
🧹 Nitpick comments (1)
.clang-tidy (1)
7-7: Consider simplifying the regex pattern for clarity.
The negative lookahead pattern (?!.*(?:/|^)(3rdparty|tvm)/) works but is somewhat unclear because ^ inside (?:/|^) always refers to the string start, making the grouping confusing.
A clearer equivalent would be:
-HeaderFilterRegex: '^(?!.*(?:/|^)(3rdparty|tvm)/).*'
+HeaderFilterRegex: '^(?!.*/(?:3rdparty|tvm)/|^(?:3rdparty|tvm)/).*'
This explicitly separates the two cases: paths containing /3rdparty/ or /tvm/, and paths starting with 3rdparty/ or tvm/.
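One way to check that the two patterns agree is to run both over a few sample paths. The sketch below uses `std::regex` with the ECMAScript grammar as a stand-in; clang-tidy's own regex engine may treat these patterns differently, so this only compares the patterns with each other. `HeaderIncluded` and the sample paths are made up for illustration.

```cpp
#include <cassert>
#include <regex>
#include <string>

// The two patterns quoted in the comment above.
const char kOriginalFilter[] = "^(?!.*(?:/|^)(3rdparty|tvm)/).*";
const char kProposedFilter[] = "^(?!.*/(?:3rdparty|tvm)/|^(?:3rdparty|tvm)/).*";

// Returns true when the whole `path` matches `pattern`, i.e. the header would
// pass the filter under ECMAScript regex semantics.
bool HeaderIncluded(const std::string& pattern, const std::string& path) {
  const std::regex re(pattern, std::regex::ECMAScript);
  return std::regex_match(path, re);
}
```

Paths such as `src/transform/arg_binder.h`, `src/tvm/x.h`, `tvm/x.h`, and `src/mytvm/x.h` exercise the included case, both excluded cases, and a directory name that merely ends in `tvm`.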
📒 Files selected for processing (2)
- .clang-tidy (1 hunks)
- src/transform/arg_binder.cc (1 hunks)
🔇 Additional comments (6)
src/transform/arg_binder.cc (6)
40-52: LGTM!
The assertion helper correctly simplifies conditions and handles three cases: always-false (fatal), always-true (no-op), and conditional (runtime assert).
54-81: LGTM!
The binding logic correctly handles new variable introductions, existing bindings, and non-variable arguments with appropriate dtype checks and assertions.
83-93: LGTM!
Array binding correctly validates size equality and binds each element with proper naming.
95-156: LGTM!
Buffer binding comprehensively validates scope, dtype, alignment, offset factors, and shape/strides with appropriate fuzzy matching support.
228-237: Verify the sub-byte dtype special case is intentional.
The code skips shape element binding for Int(4), UInt(4), and Int(1) dtypes by breaking out of the loop early. Please confirm this is the intended behavior, as it means shape elements are not bound or validated for these sub-byte types.
If this special handling is necessary, consider adding a comment explaining why sub-byte types bypass shape binding.
238-373: LGTM!
The stride, offset, device, and data pointer binding logic is comprehensive:
- Handles three buffer types (compact, auto-broadcast, explicit strides) correctly
- Properly distinguishes constant vs. variable offsets
- Includes appropriate NULL checks with size-0 array special case
- Consistently uses TypeAnnotation for strides and data pointer (lines 242, 367)
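The three-way handling attributed to the assertion helper in the 40-52 comment can be sketched as follows. `Tri` and `AddAssert` are hypothetical names; the real helper simplifies a symbolic condition with an analyzer and emits an AssertStmt, whereas this sketch takes the simplification result as an input and records a message.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical outcome of simplifying a condition ahead of time.
enum class Tri { kFalse, kTrue, kUnknown };

// Fatal on provably false, no-op on provably true, otherwise queue a runtime
// assertion message for later emission.
void AddAssert(Tri cond, const std::string& arg_name,
               std::vector<std::string>* runtime_asserts) {
  if (cond == Tri::kFalse) {
    throw std::runtime_error("Binding has an unmet assertion for " + arg_name);
  }
  if (cond == Tri::kTrue) return;  // provably true: nothing to emit
  runtime_asserts->push_back("Argument " + arg_name + " has an unsatisfied constraint");
}
```

Only the undecided case adds work at run time, which keeps the generated prologue free of checks that are already known to hold.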
@codex review
💡 Codex Review
Here are some automated review suggestions for this pull request.
void ArgBinder::BindBuffer(const Buffer &arg, const Buffer &value,
                           const std::string &arg_name, bool fuzzy_match) {
  ICHECK_EQ(arg.scope(), value.scope())
      << "Argument " << arg_name << " Buffer bind scope mismatch";
  ICHECK_EQ(arg->dtype, value->dtype)
      << "Argument " << arg_name << " Buffer bind data type mismatch";
  if (value->data_alignment % arg->data_alignment != 0) {
    LOG(WARNING) << "Trying to bind buffer to another one with lower alignment "
                    "requirement "
Avoid modulo by zero when comparing buffer alignments
The new ArgBinder::BindBuffer warns when the provided buffer’s alignment is smaller than the required one by computing value->data_alignment % arg->data_alignment. However, BufferNode::data_alignment is zero by default when no alignment is requested. If a symbolic buffer without an explicit alignment is bound here (the common case), this expression executes a modulo with a zero divisor and triggers undefined behaviour before any warning is issued. The guard should skip the check when arg->data_alignment is zero or treat it as 1/“no requirement” to avoid a runtime crash in the binder.
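A guard along the lines suggested above might look like the sketch below. `AlignmentSatisfied` is a hypothetical helper, and treating a requirement of 0 as "no requirement" is an assumption of this sketch rather than confirmed project behavior.

```cpp
#include <cassert>

// Hypothetical guard: returns true when a buffer aligned to `value_alignment`
// satisfies a requirement of `arg_alignment`. A non-positive requirement is
// treated as "no requirement" (an assumption), so the modulo is never
// evaluated with a zero divisor.
bool AlignmentSatisfied(int value_alignment, int arg_alignment) {
  if (arg_alignment <= 0) return true;  // skip the check entirely
  return value_alignment % arg_alignment == 0;
}
```

The warning path in the binder would then fire only when this function returns false.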
…alues) (tile-ai#1119)
* fix int32 dtype issue
* lint fix
* lint
* lint fix
---------
Co-authored-by: Zhiwen Mo <[email protected]>
This pull request introduces a new helper utility for argument binding in the TVM codebase, specifically under the tvm::tl namespace. The main addition is the implementation of ArgBinder, which provides a consistent way to match and bind function arguments, handle symbolic buffers, and generate necessary assertions and initializations. Several supporting changes were made to integrate this new utility and improve type safety in related vectorization logic.
Key changes:
New Argument Binding Utility
- Adds arg_binder.cc and arg_binder.h, implementing the ArgBinder class, which provides methods to bind primitive expressions, arrays, buffers, and DLTensor handles, while generating necessary variable definitions, assertions, and initialization statements. This utility is designed to standardize argument binding and constraint checking across TVM transformations.
Integration and Refactoring
- Updates make_packed_api.cc to use the new arg_binder.h path, replacing the previous include from tir/transforms/arg_binder.h with the local header.
Vectorization Improvements
- Makes IndiceCanVectorize within loop_vectorize.cc dtype-consistent by ensuring that all constants and variables used in vectorization checks and substitutions match the relevant data types. This reduces the risk of subtle bugs due to type mismatches.
Minor Cleanups
- Minor cleanup in make_packed_api.cc for consistency.