Skip to content

[Enhancement][Subtype] Enhance symbolic shape/stride handling for subtype#1599

Merged
LeiWang1999 merged 2 commits intotile-ai:mainfrom
LeiWang1999:subtype_0104
Jan 4, 2026
Merged

[Enhancement][Subtype] Enhance symbolic shape/stride handling for subtype#1599
LeiWang1999 merged 2 commits intotile-ai:mainfrom
LeiWang1999:subtype_0104

Conversation

@LeiWang1999
Copy link
Member

@LeiWang1999 LeiWang1999 commented Jan 4, 2026

as title. now we constrain the input of a subtype tensor must be in shape [..., m, n // k] where k is the compress factor (for example if the input storage is int8, and target dtype is int4, the k should be 2), as same as stride.

Summary by CodeRabbit

  • Bug Fixes

    • Improved handling of sub-byte tensor shapes and strides with dynamic symbolic dimensions.
    • Enhanced support for non-contiguous tensors and proper stride computation.
  • Tests

    • Added comprehensive test suite for sub-byte tensor shape and stride bindings across multiple configurations.

✏️ Tip: You can customize this high-level summary in your review settings.

…rgBinder

* Updated the shape binding logic in ArgBinder to handle subbyte types (bits < 8) by introducing a packing factor for logical shapes.
* Enhanced the handling of strides for subbyte types, ensuring correct mapping between logical and runtime strides.
* Refactored the binding process to accommodate both symbolic and constant dimensions, improving the robustness of the shape binding mechanism.
* Removed obsolete test file `test_ir_kernel_frame.py` as part of cleanup efforts.
@github-actions
Copy link

github-actions bot commented Jan 4, 2026

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Jan 4, 2026

📝 Walkthrough

Walkthrough

Introduces specialized subtype (bits < 8) DLTensor shape and stride binding in ArgBinder::BindDLTensors with packing-aware logic. Adds comprehensive test coverage for subtype symbolic shape and stride bindings across dynamic configurations. Removes a placeholder TODO comment from an unrelated test file.

Changes

Cohort / File(s) Summary
Subtype Shape & Stride Binding Logic
src/transform/arg_binder.cc
Branches on data_is_subtype to compute packing factors (8 / bits), bind symbolic dimensions from packed runtime shapes, and handle strides with multiplied factors for non-last dimensions. Adds NULL guards, BindNullable calls, and runtime assertions for consistency. Preserves existing behavior for non-subtype paths.
Test Suite Removals
testing/python/ir/test_ir_kernel_frame.py
Removes single TODO placeholder comment.
Subtype Binding Test Suite
testing/python/language/test_tilelang_language_subtype.py
Adds comprehensive test coverage for subtype shape/stride bindings including basic shapes, strided tensors, symbolic dimensions (last-dim, shared across buffers, complex expressions), non-contiguous tensors, and iterative validation across m, n, and stride configurations. Includes multiple kernel stubs and CUDA-gated test entry points.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~30 minutes

Possibly related PRs

Poem

🐰 Through tensor shapes both packed and wide,
A rabbit binds with packing pride,
Subtype strides now multiply,
Symbolic dims and tests align,
Hopping forward—all is fine!

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The pull request title clearly and accurately summarizes the main change: enhancing symbolic shape and stride handling for subtype tensors.
Docstring Coverage ✅ Passed Docstring coverage is 85.71% which is sufficient. The required threshold is 80.00%.
✨ Finishing touches
  • 📝 Generate docstrings

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (1)
src/transform/arg_binder.cc (1)

600-605: Consider extracting pack_factor calculation to reduce duplication.

The pack_factor calculation (8 / buffer->dtype.bits()) is duplicated in both the shape binding block (line 605) and the stride binding block (line 790). Since both are guarded by the same data_is_subtype condition, consider computing it once before these blocks.

🔎 Proposed refactor
     // Bind symbolic variables from buffer shape
     // For subtype (bits < 8), the runtime tensor has packed shape...
+    const int pack_factor = data_is_subtype ? (8 / buffer->dtype.bits()) : 1;
+
     if (data_is_subtype) {
       // For subtype, bind symbolic variables from the packed shape.
-      // The packing factor k = 8 / bits...
-      int bits = buffer->dtype.bits();
-      int pack_factor = 8 / bits;
+      // Use pre-computed pack_factor

Also applies to: 785-790

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 946f611 and f35a2cf.

📒 Files selected for processing (3)
  • src/transform/arg_binder.cc
  • testing/python/ir/test_ir_kernel_frame.py
  • testing/python/language/test_tilelang_language_subtype.py
💤 Files with no reviewable changes (1)
  • testing/python/ir/test_ir_kernel_frame.py
🧰 Additional context used
🧬 Code graph analysis (1)
src/transform/arg_binder.cc (1)
tilelang/language/ast/ir.py (3)
  • LetStmt (874-902)
  • decl_buffer (1134-1202)
  • handle (1465-1492)
🪛 Ruff (0.14.10)
testing/python/language/test_tilelang_language_subtype.py

10-10: Unused function argument: x

(ARG001)


19-19: Unused function argument: x

(ARG001)


99-99: Unused function argument: x

(ARG001)


109-109: Unused function argument: x

(ARG001)


120-120: Unused function argument: x

(ARG001)


120-120: Unused function argument: y

(ARG001)


131-131: Unused function argument: x

(ARG001)


131-131: Unused function argument: y

(ARG001)


143-143: Unused function argument: x

(ARG001)


143-143: Unused function argument: y

(ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (6)
src/transform/arg_binder.cc (2)

595-650: Subtype shape binding logic looks correct.

The implementation correctly handles the packed shape semantics where the last dimension's logical size equals runtime_shape * pack_factor. The use of BindNullable for both first-time bindings and consistency assertions is appropriate.

One minor observation: the pack_factor calculation at line 605 assumes bits evenly divides 8 (valid for 1, 2, 4-bit types). Consider adding a debug assertion if non-power-of-2 subbyte types might be introduced in the future.


651-777: Non-subtype shape binding with cascaded source handling looks correct.

The implementation properly handles symbolic variables that appear in multiple buffers by:

  1. Asserting at least one source buffer is non-null (line 696-697)
  2. Building a cascaded if_then_else expression from multiple sources (lines 705-751)
  3. Falling back to single-source BindNullable when appropriate (lines 762-763)
testing/python/language/test_tilelang_language_subtype.py (4)

9-15: Static analysis warnings are false positives for DSL pattern.

The Ruff ARG001 warnings about unused x and y arguments are false positives. In this TileLang DSL, function parameters are consumed via type annotations (e.g., x: T.Tensor[(m, 16), T.float4_e2m1fn]) that the @tilelang.lazy_jit decorator processes to define tensor layouts.

This is an expected pattern for embedded DSLs where the decorator introspects annotations rather than traditional argument usage.


142-151: Complex expression kernel correctly tests cross-buffer symbolic relationships.

The kernel definition creates an interesting test case where m and n appear in different expressions across buffers:

  • x shape (m, n*2) with last-dim pack factor → runtime (m, n)
  • y shape (m*2, n) with last-dim pack factor → runtime (m*2, n/2)

This effectively tests that the symbolic solver can correctly derive m and n from the constrained system.


229-245: Good test for shared symbolic stride with non-contiguous tensors.

The test correctly creates two tensors with matching runtime strides (16, 1) via slicing, which translates to matching logical strides (32, 1) after applying pack_factor=2. This validates that the symbolic stride variable s is correctly resolved to the same value from both buffers.


1-7: Comprehensive test suite for subtype shape/stride bindings.

The test file provides thorough coverage of the new subtype binding functionality:

  • Basic and strided tensor bindings
  • Symbolic variables in different dimensions (first, last, shared across buffers)
  • Non-contiguous tensor handling
  • Complex expressions with multiple symbolic variables

The test structure is well-organized with clear docstrings explaining the expected behavior.

Comment on lines +785 to +833
if (data_is_subtype) {
// For subtype, only process strides if there are explicit strides
// with symbolic variables that need binding
if (!buffer->strides.empty()) {
int bits = buffer->dtype.bits();
int pack_factor = 8 / bits;

Buffer buf_strides =
decl_buffer({IntImm(DataType::Int(32), buffer->strides.size())},
tvm_shape_type, arg_name + ".strides");
def_handle_dtype_.Set(buf_strides->data,
tir::TypeAnnotation(tvm_shape_type));
init_nest_.emplace_back(
LetStmt(buf_strides->data,
tvm::if_then_else(Not(is_null),
TVMArrayGet(DataType::Handle(), handle,
builtin::kArrStrides),
make_zero(DataType::Handle())),
nop));
init_nest_.emplace_back(DeclBuffer(buf_strides, nop));
PrimExpr v_strides_is_null =
Call(DataType::Bool(1), builtin::isnullptr(), {buf_strides->data});

for (int k = static_cast<int>(buffer->strides.size()) - 1; k >= 0;
--k) {
DataType stride_dtype = buffer->strides[k].dtype();
PrimExpr runtime_stride =
cast(stride_dtype,
BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)}));
runtime_stride =
tvm::if_then_else(Or(is_null, v_strides_is_null),
make_const(stride_dtype, 0), runtime_stride);

// For the last dimension, logical stride = runtime stride
// For other dimensions, logical stride = runtime stride * pack_factor
PrimExpr logical_stride_val;
bool is_last_dim =
(k == static_cast<int>(buffer->strides.size()) - 1);
if (is_last_dim) {
logical_stride_val = runtime_stride;
} else {
logical_stride_val =
runtime_stride * make_const(stride_dtype, pack_factor);
}

BindNullable(buffer->strides[k], logical_stride_val,
stride_element_name(k), true, is_null);
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Missing compactness validation for subtype buffers with empty strides.

When buffer->strides.empty() is true for subtypes, the code skips stride handling entirely (line 788). However, the non-subtype branch (lines 852-894) validates buffer compactness in this case.

If a subtype buffer is declared contiguous (no explicit strides) but the runtime provides a non-contiguous tensor, this could lead to silent incorrect behavior.

🔎 Suggested fix: Add compactness check for subtype with empty strides
     if (data_is_subtype) {
       // For subtype, only process strides if there are explicit strides
       // with symbolic variables that need binding
       if (!buffer->strides.empty()) {
         // ... existing stride binding code ...
+      } else {
+        // For subtype with no explicit strides, verify compactness
+        // Similar to non-subtype empty strides handling (lines 852-894)
+        // The packed stride for dimension k is:
+        //   packed_stride[k] = product of packed_shape[k+1:] 
+        // where packed_shape[-1] = logical_shape[-1] / pack_factor
       }
     }

Committable suggestion skipped: line range outside the PR's diff.

@LeiWang1999 LeiWang1999 merged commit e3c9a58 into tile-ai:main Jan 4, 2026
6 of 7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant