[Enhancement][Subtype] Enhance symbolic shape/stride handling for subtype#1599
LeiWang1999 merged 2 commits into tile-ai:main from
Conversation
…rgBinder

* Updated the shape binding logic in ArgBinder to handle subbyte types (bits < 8) by introducing a packing factor for logical shapes.
* Enhanced the handling of strides for subbyte types, ensuring a correct mapping between logical and runtime strides.
* Refactored the binding process to accommodate both symbolic and constant dimensions, improving the robustness of the shape binding mechanism.
* Removed the obsolete test file `test_ir_kernel_frame.py` as part of cleanup efforts.
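The logical-to-runtime shape relationship described above can be sketched in plain Python. This is a hypothetical illustration, not the actual C++ implementation; the helper names are invented, and the last-dimension-only packing follows the PR description:

```python
def pack_factor(bits: int) -> int:
    # e.g. a 4-bit element type packs 2 logical elements per 8-bit storage unit
    assert 8 % bits == 0, "subbyte width must evenly divide 8"
    return 8 // bits

def to_runtime_shape(logical_shape: tuple, bits: int) -> tuple:
    # Only the last dimension is packed: (m, n) -> (m, n // k)
    k = pack_factor(bits)
    assert logical_shape[-1] % k == 0
    return logical_shape[:-1] + (logical_shape[-1] // k,)

def to_logical_shape(runtime_shape: tuple, bits: int) -> tuple:
    # Inverse mapping: (m, n) -> (m, n * k)
    k = pack_factor(bits)
    return runtime_shape[:-1] + (runtime_shape[-1] * k,)

print(to_runtime_shape((4, 16), 4))  # (4, 8)
print(to_logical_shape((4, 8), 4))   # (4, 16)
```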
📝 Walkthrough

Introduces specialized subtype (bits < 8) DLTensor shape and stride binding in `ArgBinder::BindDLTensors` with packing-aware logic. Adds comprehensive test coverage for subtype symbolic shape and stride bindings across dynamic configurations. Removes a placeholder TODO comment from an unrelated test file.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~30 minutes
Actionable comments posted: 1
🧹 Nitpick comments (1)
src/transform/arg_binder.cc (1)
600-605: **Consider extracting the pack_factor calculation to reduce duplication.**

The pack_factor calculation (`8 / buffer->dtype.bits()`) is duplicated in both the shape binding block (line 605) and the stride binding block (line 790). Since both are guarded by the same `data_is_subtype` condition, consider computing it once before these blocks.

🔎 Proposed refactor

```diff
 // Bind symbolic variables from buffer shape
 // For subtype (bits < 8), the runtime tensor has packed shape...
+const int pack_factor = data_is_subtype ? (8 / buffer->dtype.bits()) : 1;
+
 if (data_is_subtype) {
   // For subtype, bind symbolic variables from the packed shape.
-  // The packing factor k = 8 / bits...
-  int bits = buffer->dtype.bits();
-  int pack_factor = 8 / bits;
+  // Use the pre-computed pack_factor.
```

Also applies to: 785-790
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- src/transform/arg_binder.cc
- testing/python/ir/test_ir_kernel_frame.py
- testing/python/language/test_tilelang_language_subtype.py
💤 Files with no reviewable changes (1)
- testing/python/ir/test_ir_kernel_frame.py
🧰 Additional context used
🧬 Code graph analysis (1)
src/transform/arg_binder.cc (1)
tilelang/language/ast/ir.py (3)
`LetStmt` (874-902), `decl_buffer` (1134-1202), `handle` (1465-1492)
🪛 Ruff (0.14.10)
testing/python/language/test_tilelang_language_subtype.py
10-10: Unused function argument: x
(ARG001)
19-19: Unused function argument: x
(ARG001)
99-99: Unused function argument: x
(ARG001)
109-109: Unused function argument: x
(ARG001)
120-120: Unused function argument: x
(ARG001)
120-120: Unused function argument: y
(ARG001)
131-131: Unused function argument: x
(ARG001)
131-131: Unused function argument: y
(ARG001)
143-143: Unused function argument: x
(ARG001)
143-143: Unused function argument: y
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (6)
src/transform/arg_binder.cc (2)
595-650: **Subtype shape binding logic looks correct.**

The implementation correctly handles the packed shape semantics, where the last dimension's logical size equals `runtime_shape * pack_factor`. The use of `BindNullable` for both first-time bindings and consistency assertions is appropriate.

One minor observation: the pack_factor calculation at line 605 assumes `bits` evenly divides 8 (valid for 1-, 2-, and 4-bit types). Consider adding a debug assertion in case non-power-of-2 subbyte types are introduced in the future.
651-777: **Non-subtype shape binding with cascaded source handling looks correct.**

The implementation properly handles symbolic variables that appear in multiple buffers by:

- Asserting that at least one source buffer is non-null (lines 696-697)
- Building a cascaded `if_then_else` expression from multiple sources (lines 705-751)
- Falling back to single-source `BindNullable` when appropriate (lines 762-763)

testing/python/language/test_tilelang_language_subtype.py (4)
9-15: **Static analysis warnings are false positives for the DSL pattern.**

The Ruff ARG001 warnings about unused `x` and `y` arguments are false positives. In this TileLang DSL, function parameters are consumed via type annotations (e.g., `x: T.Tensor[(m, 16), T.float4_e2m1fn]`) that the `@tilelang.lazy_jit` decorator processes to define tensor layouts. This is an expected pattern for embedded DSLs, where the decorator introspects annotations rather than relying on traditional argument usage.
142-151: **Complex expression kernel correctly tests cross-buffer symbolic relationships.**

The kernel definition creates an interesting test case where `m` and `n` appear in different expressions across buffers:

- `x` has shape `(m, n*2)` with a last-dim pack factor → runtime `(m, n)`
- `y` has shape `(m*2, n)` with a last-dim pack factor → runtime `(m*2, n/2)`

This effectively tests that the symbolic solver can correctly derive `m` and `n` from the constrained system.
229-245: **Good test for shared symbolic stride with non-contiguous tensors.**

The test correctly creates two tensors with matching runtime strides `(16, 1)` via slicing, which translates to matching logical strides `(32, 1)` after applying `pack_factor = 2`. This validates that the symbolic stride variable `s` is correctly resolved to the same value from both buffers.
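The stride mapping discussed in this review (last-dimension stride unchanged, outer strides scaled by the pack factor) can be sketched in Python. This is a hypothetical illustration, not the ArgBinder code; the function name is invented:

```python
def logical_strides(runtime_strides: tuple, bits: int) -> tuple:
    # Each runtime element holds k logical elements, so every stride
    # except the innermost one scales by k; the last stride is unchanged.
    k = 8 // bits
    return tuple(s * k for s in runtime_strides[:-1]) + (runtime_strides[-1],)

# 4-bit elements (k = 2): runtime strides (16, 1) -> logical strides (32, 1)
print(logical_strides((16, 1), 4))  # (32, 1)
```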
1-7: Comprehensive test suite for subtype shape/stride bindings.The test file provides thorough coverage of the new subtype binding functionality:
- Basic and strided tensor bindings
- Symbolic variables in different dimensions (first, last, shared across buffers)
- Non-contiguous tensor handling
- Complex expressions with multiple symbolic variables
The test structure is well-organized with clear docstrings explaining the expected behavior.
```cpp
if (data_is_subtype) {
  // For subtype, only process strides if there are explicit strides
  // with symbolic variables that need binding
  if (!buffer->strides.empty()) {
    int bits = buffer->dtype.bits();
    int pack_factor = 8 / bits;

    Buffer buf_strides =
        decl_buffer({IntImm(DataType::Int(32), buffer->strides.size())},
                    tvm_shape_type, arg_name + ".strides");
    def_handle_dtype_.Set(buf_strides->data,
                          tir::TypeAnnotation(tvm_shape_type));
    init_nest_.emplace_back(
        LetStmt(buf_strides->data,
                tvm::if_then_else(Not(is_null),
                                  TVMArrayGet(DataType::Handle(), handle,
                                              builtin::kArrStrides),
                                  make_zero(DataType::Handle())),
                nop));
    init_nest_.emplace_back(DeclBuffer(buf_strides, nop));
    PrimExpr v_strides_is_null =
        Call(DataType::Bool(1), builtin::isnullptr(), {buf_strides->data});

    for (int k = static_cast<int>(buffer->strides.size()) - 1; k >= 0;
         --k) {
      DataType stride_dtype = buffer->strides[k].dtype();
      PrimExpr runtime_stride =
          cast(stride_dtype,
               BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)}));
      runtime_stride =
          tvm::if_then_else(Or(is_null, v_strides_is_null),
                            make_const(stride_dtype, 0), runtime_stride);

      // For the last dimension, logical stride = runtime stride
      // For other dimensions, logical stride = runtime stride * pack_factor
      PrimExpr logical_stride_val;
      bool is_last_dim =
          (k == static_cast<int>(buffer->strides.size()) - 1);
      if (is_last_dim) {
        logical_stride_val = runtime_stride;
      } else {
        logical_stride_val =
            runtime_stride * make_const(stride_dtype, pack_factor);
      }

      BindNullable(buffer->strides[k], logical_stride_val,
                   stride_element_name(k), true, is_null);
    }
  }
}
```
**Missing compactness validation for subtype buffers with empty strides.**

When `buffer->strides.empty()` is true for subtypes, the code skips stride handling entirely (line 788). However, the non-subtype branch (lines 852-894) validates buffer compactness in this case. If a subtype buffer is declared contiguous (no explicit strides) but the runtime provides a non-contiguous tensor, this could lead to silently incorrect behavior.
🔎 Suggested fix: Add compactness check for subtype with empty strides

```diff
 if (data_is_subtype) {
   // For subtype, only process strides if there are explicit strides
   // with symbolic variables that need binding
   if (!buffer->strides.empty()) {
     // ... existing stride binding code ...
+  } else {
+    // For subtype with no explicit strides, verify compactness
+    // Similar to non-subtype empty strides handling (lines 852-894)
+    // The packed stride for dimension k is:
+    //   packed_stride[k] = product of packed_shape[k+1:]
+    //   where packed_shape[-1] = logical_shape[-1] / pack_factor
   }
 }
```
}Committable suggestion skipped: line range outside the PR's diff.
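The compactness check suggested above amounts to comparing the runtime strides against the row-major strides of the packed shape. A hypothetical Python sketch of that expected-stride computation (the function name is invented for illustration):

```python
def expected_packed_strides(logical_shape: tuple, bits: int) -> list:
    # Packed shape: last logical dimension divided by the pack factor.
    k = 8 // bits
    packed = list(logical_shape[:-1]) + [logical_shape[-1] // k]
    # Row-major (C-contiguous) strides over the packed shape.
    strides = [1] * len(packed)
    for i in range(len(packed) - 2, -1, -1):
        strides[i] = strides[i + 1] * packed[i + 1]
    return strides

# logical (4, 16) with 4-bit elements -> packed (4, 8) -> strides [8, 1]
print(expected_packed_strides((4, 16), 4))  # [8, 1]
```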
As the title says. We now constrain the input of a subtype tensor to have shape `[..., m, n // k]`, where `k` is the compress factor (for example, if the input storage is int8 and the target dtype is int4, `k` is 2). The same constraint applies to strides.
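A minimal NumPy illustration of this constraint, under the assumed setup from the example above (int8 storage carrying 4-bit values, so `k = 2`):

```python
import numpy as np

m, n = 4, 16
k = 2  # compress factor: 8-bit storage dtype / 4-bit target dtype

# The runtime buffer passed to the kernel must have shape [..., m, n // k]:
storage = np.zeros((m, n // k), dtype=np.int8)
print(storage.shape)  # (4, 8)

# Each int8 byte holds k = 2 logical 4-bit elements, so the logical
# tensor seen by the kernel has shape (m, n) = (4, 16).
```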
Summary by CodeRabbit
Bug Fixes
Tests