Skip to content

[ArgBinder] Enhance shape variable handling and assertions#1467

Merged
LeiWang1999 merged 6 commits intotile-ai:mainfrom
LeiWang1999:nullable
Dec 19, 2025
Merged

[ArgBinder] Enhance shape variable handling and assertions#1467
LeiWang1999 merged 6 commits intotile-ai:mainfrom
LeiWang1999:nullable

Conversation

@LeiWang1999
Copy link
Member

@LeiWang1999 LeiWang1999 commented Dec 18, 2025

  • Implemented special handling for comparing if_then_else expressions to simplify conditions involving NULL checks.
  • Added methods to set shared shape variables and finalize deferred bindings, generating cascading if_then_else expressions and runtime assertions for non-NULL buffers.
  • Updated the binding logic to defer shape variable bindings for shared variables, ensuring proper handling across multiple nullable buffers.

Summary by CodeRabbit

  • New Features

    • Support for nullable/optional buffers that share symbolic shapes, with runtime selection of the first non-null source.
  • Bug Fixes

    • Safer handling of NULL buffer handles and pointers with runtime checks and clearer error messages for mismatched dimensions/dtypes.
  • Refactor

    • Unified batch-style buffer binding to improve handling of multiple related buffers and deferred validation.
  • Tests

    • Added an end-to-end test exercising nullable shared-shape buffer scenarios.

✏️ Tip: You can customize this high-level summary in your review settings.

- Implemented special handling for comparing if_then_else expressions to simplify conditions involving NULL checks.
- Added methods to set shared shape variables and finalize deferred bindings, generating cascading if_then_else expressions and runtime assertions for non-NULL buffers.
- Updated the binding logic to defer shape variable bindings for shared variables, ensuring proper handling across multiple nullable buffers.
@github-actions
Copy link

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Dec 18, 2025

Walkthrough

Replaces single-buffer DLTensor binding with a batch-style BindDLTensors that processes multiple nullable buffers in two passes: collect shared symbolic shape sources, then emit per-buffer NULL guards, guarded shape/stride/device/dtype bindings, and deferred/cascaded runtime checks for shared shapes and pointer consistency.

Changes

Cohort / File(s) Summary
ArgBinder implementation
src/transform/arg_binder.cc
Replaced single-buffer binding with BindDLTensors(...). Implements two-pass binding: first collect symbolic shape vars and sources across buffers; second create per-buffer is_null guards, guarded loads for shape/strides/device/dtype, deferred shape bindings, cascaded binding expressions for shared shape vars, and runtime packed-call assertions on mismatches. Refactored internal maps/lists to support batch bindings and deferred finalization.
ArgBinder API
src/transform/arg_binder.h
Public API changed: removed BindDLTensor(...) and added BindDLTensors(const std::vector<std::pair<Var, Buffer>>&, const PrimExpr&, const PrimExpr&, const std::string&, const std::unordered_set<const VarNode*>&). Updated signatures/comments to reflect batch binding and used-buffer set.
make_packed_api flow
src/transform/make_packed_api.cc
Calls new BindDLTensors(...) during binder phase; removed prior logic that anchored a non-NULL carrier and removed per-buffer display-name binding path. Relies on new binder logic to handle nullable carriers and shared-shape cascades.
Tests
testing/python/transform/test_nullable_buffer_params.py
New test exercising nullable buffer parameters and shared dynamic shape behavior: runs kernel with all buffers present, subsets present, and all-None case (expects runtime error for “at least one non-null buffer”).
Debugging output
tilelang/engine/phase.py
Added two debug print statements after LowerDeviceKernelLaunch in OptimizeForTarget to print a marker and the IRModule; an exit call was commented out.

Sequence Diagram(s)

sequenceDiagram
    participant MPA as make_packed_api
    participant AB as ArgBinder
    participant Buf as Buffers (nullable handles)
    participant Runtime as Runtime (packed calls / assertions)

    MPA->>AB: BindDLTensors(buffer_defs, device_type, device_id, func_name, used_set)
    AB->>AB: First pass — collect symbolic shape vars & source buffers
    loop For each buffer
        Buf->>AB: expose handle (may be NULL)
        AB->>AB: create is_null guard for handle
        AB->>AB: if not NULL -> guarded loads: shape, strides, device, dtype, data ptr
        AB->>AB: record shape sources or defer binding if shared
    end
    AB->>AB: Second pass — materialize bindings
    alt Shared shape vars present
        AB->>AB: build cascaded binding expr from first non-NULL source
        AB->>Runtime: emit packed-call/assert: "≥1 non-NULL source" if all NULL
        AB->>Runtime: emit packed-call/assert: shape equality checks across buffers
    end
    AB->>Runtime: emit other runtime checks (dtype/device/byte-offset) guarded by is_null
    AB->>MPA: return finalized bindings & guarded checks
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

  • Focus review on:
    • src/transform/arg_binder.cc: two-pass collection logic, cascaded binding expression correctness, NULL-guard placement, and emitted packed-call assertions.
    • src/transform/make_packed_api.cc: integration with new API and removal of prior anchoring logic.
    • testing/python/transform/test_nullable_buffer_params.py: test coverage and expected error message matching.

Possibly related PRs

Poem

🐰 I hopped through shapes both shared and sly,

buffers nullable under the sky,
Guards stitched tight, cascades now sing,
One non-null source wakes everything,
Hooray — bindings safe, the rabbit gives a joyful cry!

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 18.18% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title '[ArgBinder] Enhance shape variable handling and assertions' accurately describes the main changes, which involve refactoring shape variable handling, introducing a two-pass processing approach, and adding runtime assertion frameworks for nullable buffers.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (1)
src/transform/arg_binder.cc (1)

161-232: Verify the cascading expression build order matches the intended semantics.

The cascading expression is built by iterating in reverse order, but the fallback for the last binding (first in reverse) is temp_var directly rather than a sentinel value like 0. This means if all buffers are NULL except the last one processed, the shape will still be extracted correctly.

However, I want to confirm: the comment on line 172 says the fallback should be 0, but the implementation uses the last temp_var as the base. This appears intentional (to avoid an invalid zero shape value), but the comment is slightly misleading.

Also, the pairwise equality constraints (lines 212-230) generate O(n²) assertions for n buffers. For typical use cases with 2-3 buffers this is fine, but worth noting for scalability.

🔎 Consider clarifying the comment to match the implementation:
     // Step 1: Build cascading if_then_else expression
     // Traverse bindings in reverse order to build:
-    // if_then_else(!is_null_0, temp_0, if_then_else(!is_null_1, temp_1, 0))
+    // if_then_else(!is_null_0, temp_0, if_then_else(!is_null_1, temp_1, temp_last))
+    // The last binding's temp_var is used as the final fallback (not 0).
     PrimExpr cascading_expr;
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7248a81 and 2514ef7.

📒 Files selected for processing (3)
  • src/transform/arg_binder.cc (3 hunks)
  • src/transform/arg_binder.h (1 hunks)
  • src/transform/make_packed_api.cc (4 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Quick Lint
🔇 Additional comments (7)
src/transform/make_packed_api.cc (2)

384-413: LGTM! Proper handling of shared shape variables across multiple carriers.

The logic correctly distinguishes between:

  1. Single carrier → must be non-NULL to provide the shape (immediate binding)
  2. Multiple carriers → shape can come from any non-NULL buffer (deferred binding via cascading if_then_else)

This enables nullable buffer support where shape variables can be extracted from whichever buffer is non-NULL at runtime.


524-537: LGTM! Correct ordering of SetSharedShapeVars before buffer binding and finalization after.

The sequence is correct:

  1. SetSharedShapeVars before BindDLTensor calls so the binder knows which shape vars to defer
  2. FinalizeDeferredBindings after all buffers are bound to generate cascading expressions and assertions
src/transform/arg_binder.cc (3)

31-91: Well-designed pattern matching for simplifying if_then_else comparisons.

The logic correctly transforms:

if_then_else(guard_a, val_a, fallback) == if_then_else(guard_b, val_b, fallback)

into:

!guard_a || !guard_b || (val_a == val_b)

This avoids spurious assertion failures when comparing guarded shape values where one or both buffers may be NULL. The Cast unwrapping is also handled correctly.


156-159: LGTM! Simple setter for shared shape variables.


658-689: LGTM! Correct deferred binding logic for shared shape variables.

The implementation properly:

  1. Checks if the shape dimension is a shared variable via shared_shape_vars_.count
  2. Creates a temporary variable to hold the buffer-specific shape value
  3. Records the deferred binding with both temp_var and is_null_var
  4. Falls back to immediate BindNullable for non-shared variables

The naming convention (arg_name + "_shape_dim_" + k + "_temp") provides good traceability for debugging.

src/transform/arg_binder.h (2)

161-176: LGTM! Clear and well-documented public API.

The documentation properly describes:

  1. Purpose of SetSharedShapeVars for marking variables needing special handling
  2. What FinalizeDeferredBindings generates (cascading if_then_else + runtime assertion)

Minor note: The documentation could also mention the pairwise equality constraints that are generated, but this is optional.


179-190: LGTM! Well-structured private state for deferred binding tracking.

The DeferredShapeBinding struct cleanly captures the essential information needed for finalization, and the map structure allows efficient lookup by shape variable node.

…sed code

- Removed deprecated methods for setting shared shape variables and finalizing deferred bindings, streamlining the argument binding process.
- Simplified the logic for handling shape values in the `BindDLTensor` function, ensuring immediate binding for normal shape variables.
- Enhanced clarity by eliminating unnecessary comments and code related to cascading if_then_else expressions for shared variables.
…ndling

- Replaced the single `BindDLTensor` method with `BindDLTensors` to support multiple buffers, improving flexibility in handling DLTensor bindings.
- Introduced a two-pass approach for shape variable handling, allowing for better management of symbolic dimensions and null checks.
- Updated the logic to assert non-null conditions at runtime and utilize cascaded if_then_else expressions for shape retrieval, enhancing robustness.
- Removed deprecated code and streamlined the binding process for clarity and maintainability.
…in test output

- Updated string formatting for better readability in the `test_nullable_shared_shape` function.
- Ensured consistent use of double quotes for string literals.
- Added a missing newline at the end of the file for proper formatting.
…ensors

- Streamlined the calculation of allocation size by replacing a lambda function with a direct loop, enhancing readability and maintainability.
- Improved clarity in the null check message for data pointers, ensuring better understanding of the binding process.
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (1)
src/transform/make_packed_api.cc (1)

516-523: Batch binding integration looks correct.

The switch from per-buffer BindDLTensor calls to a single BindDLTensors batch call properly integrates the new nullable buffer handling. All required parameters are passed correctly.

Minor: Extra blank line at 523 could be removed for consistency.

🔎 Apply this diff to remove the extra blank line:
   for (const auto &[var, buffer] : buffer_def) {
     // Prefer buffer data var name in diagnostics to avoid exposing low-level
     // handle vars
     arg_buffer_declarations.push_back(DeclBuffer(buffer, nop));
   }
-
   // reset global symbol to attach prefix
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2514ef7 and 7d297b0.

📒 Files selected for processing (5)
  • src/transform/arg_binder.cc (1 hunks)
  • src/transform/arg_binder.h (2 hunks)
  • src/transform/make_packed_api.cc (3 hunks)
  • testing/python/transform/test_nullable_buffer_params.py (1 hunks)
  • tilelang/engine/phase.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.

Applied to files:

  • testing/python/transform/test_nullable_buffer_params.py
🧬 Code graph analysis (2)
testing/python/transform/test_nullable_buffer_params.py (3)
tilelang/language/symbolics.py (1)
  • dynamic (10-21)
tilelang/language/v2/dtypes.py (2)
  • int32 (202-202)
  • float32 (257-257)
tilelang/language/kernel.py (1)
  • threads (214-218)
src/transform/arg_binder.h (1)
src/transform/arg_binder.cc (2)
  • BindDLTensors (314-925)
  • BindDLTensors (314-318)
🪛 Ruff (0.14.8)
testing/python/transform/test_nullable_buffer_params.py

16-16: Unused function argument: a

(ARG001)


17-17: Unused function argument: b

(ARG001)


18-18: Unused function argument: c

(ARG001)


59-59: Consider moving this statement to an else block

(TRY300)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Quick Lint
🔇 Additional comments (7)
testing/python/transform/test_nullable_buffer_params.py (1)

7-69: LGTM!

The test comprehensively exercises nullable buffer parameter handling by testing various combinations of NULL and non-NULL buffers. The logic correctly validates that at least one non-NULL buffer is required for shared shape variables.

Note: The static analysis warnings about unused function arguments (a, b, c) are false positives—these parameters are intentionally unused as the test focuses on NULL pointer handling rather than computation.

src/transform/make_packed_api.cc (1)

396-404: Clear documentation of the new nullable binding approach.

The comment effectively explains how the new BindDLTensors method handles nullable carriers through runtime assertions and cascaded if_then_else expressions, eliminating the need to force one carrier to be non-NULL.

src/transform/arg_binder.h (1)

99-112: Well-designed API evolution from single-buffer to batch binding.

The updated BindDLTensors signature properly supports the new nullable buffer handling model by accepting a vector of buffer definitions and a set of used buffers, enabling efficient batch processing with shared shape variable coordination.

src/transform/arg_binder.cc (4)

314-347: Well-structured first pass for shape variable collection.

The first pass correctly collects all symbolic shape variables across buffers and tracks their sources (buffer name, dimension index, handle pointer). This enables efficient coordination of shared shape variables in the second pass.


349-409: Proper NULL guard and shape buffer initialization.

The implementation correctly:

  • Creates per-buffer is_null guards with explicit NULL checks for unused buffers
  • Pre-creates all shape buffers with guarded TVMArrayGet calls to avoid dereferencing NULL handles
  • Emits non-NULL assertions for used buffers

594-698: Sophisticated cascaded binding for shared symbolic shape variables.

The cascaded if_then_else construction for shared shape variables is well-designed:

  • Asserts at least one carrier buffer is non-NULL (lines 603-627)
  • Builds cascaded expression in reverse order to prioritize first non-NULL buffer (lines 635-673)
  • Properly handles used (non-nullable) buffers by always using their shape values (lines 654-666)
  • Falls back to nullable binding for single-source or first-time variables (lines 681-686)

The reverse iteration pattern ensures correct precedence when multiple buffers share the same shape variable.


411-924: Comprehensive and robust batch binding implementation.

The main binding loop properly handles all DLTensor fields with:

  • Guarded field access using if_then_else to prevent NULL dereference throughout
  • Structured error messages via tvm_call_packed with kernel/buffer context
  • Proper dtype compatibility checks including FP8 variants and bool handling
  • Stride binding with auto-broadcast and compact array support
  • Byte offset validation for both constant and symbolic offsets
  • Device type/ID consistency checks
  • Data pointer NULL validation with appropriate guards for size-0 arrays

The implementation correctly coordinates nullable buffers with shared shape variables while maintaining safe memory access patterns.

Removed debug print statements after MakePackedAPI transformation.
@LeiWang1999 LeiWang1999 merged commit f6db201 into tile-ai:main Dec 19, 2025
5 of 6 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant