[Enhancement] Improve handling of negative indices for ramp and broadcast node #1207
Conversation
…ve_index pass

* Added logic to handle scalar and vector indices separately, enhancing the ability to determine non-negativity and negativity of indices.
* Introduced detailed logging for cases where non-negativity cannot be proven, improving debugging capabilities.
* Refactored index state determination for vector types, including support for Ramp and Broadcast nodes.
Walkthrough

Enhanced index legalization logic in a single transformation file. Scalar index handling now proves both non-negativity and negativity conditions. Vector index handling added with per-pattern analysis for RampNode and BroadcastNode to determine state (NonNegative, Negative, or Unknown). State propagation updated accordingly with appropriate logging when proofs fail.
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes
Actionable comments posted: 0
🧹 Nitpick comments (1)
src/transform/legalize_negative_index.cc (1)
94-110: Consider applying the fallback pattern to scalar indices.

The `BroadcastNode` handling uses a two-tier approach: first attempting `CanProve`, then falling back to `const_int_bound`. This is more robust than the scalar handling (lines 42-51), which only uses `CanProve`. Applying the same fallback pattern to scalar indices could reduce false unknowns.

For consistency, consider refactoring scalar index handling:

```diff
 if (simplified.dtype().lanes() == 1) {
   if (analyzer_.CanProve(simplified >= 0)) {
     states.push_back(IndexSignState::kNonNegative);
     continue;
   }
   if (analyzer_.CanProve(simplified < 0)) {
     states.push_back(IndexSignState::kNegative);
     needs_record = true;
     continue;
   }
+  // Fallback to bound analysis if proof unavailable
+  auto bound = analyzer_.const_int_bound(simplified);
+  if (bound->min_value >= 0) {
+    states.push_back(IndexSignState::kNonNegative);
+    continue;
+  }
+  if (bound->max_value < 0) {
+    states.push_back(IndexSignState::kNegative);
+    needs_record = true;
+    continue;
+  }
   states.push_back(IndexSignState::kUnknown);
   needs_record = true;
   LOG(WARNING) << "LegalizeNegativeIndex: cannot prove non-negative index "
                << simplified << " for buffer " << load->buffer->name
                << " (axis " << i << ").";
   continue;
 }
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
`src/transform/legalize_negative_index.cc` (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Quick Lint
🔇 Additional comments (3)
src/transform/legalize_negative_index.cc (3)
42-58: Good improvement for scalar index handling.

The explicit separation of scalar index analysis improves code clarity. The logic correctly handles the three states (non-negative, negative, unknown) with appropriate warning messages for debugging.
113-121: LGTM!

The control flow correctly handles the determined vector states (non-negative and negative) by pushing the state and continuing, leaving the unknown case to fall through to the existing warning logic.
73-85: Manual verification required for overflow and sentinel value handling.

The code contains arithmetic operations that could potentially overflow (`s_min * (lanes - 1)` and `s_max * (lanes - 1)`), but the author's comment states this is intentionally not guarded because sentinel values (representing ±∞) are used for bounds.

While I confirmed the code matches the review comment and found similar patterns throughout the codebase without explicit overflow checks, I was unable to locate the `ConstIntBound` class definition to verify:

- Which sentinel constants are actually used
- How arithmetic operations behave with these sentinel values
- Whether the author's claim is technically correct

Please verify whether TVM's sentinel value design actually prevents the overflow issue or if overflow guards should be added.
```cpp
int64_t base_min = base_bound->min_value;
int64_t base_max = base_bound->max_value;
int64_t s_min = stride_bound->min_value;
int64_t s_max = stride_bound->max_value;

// Guard against overflow is not strictly necessary here because
// bounds may be +/-inf represented by sentinel values.
int64_t lower = base_min;
if (s_min < 0)
  lower += s_min * (lanes - 1);
int64_t upper = base_max;
if (s_max > 0)
  upper += s_max * (lanes - 1);
```
Guard against overflow when combining ramp bounds
Here `lower += s_min * (lanes - 1)` and the analogous `upper` update operate directly on the results of `const_int_bound`. When either `base_bound` or `stride_bound` returns the sentinel ±∞ (the common case when the analyzer can't tighten a bound), these additions run on `INT64_MIN`/`INT64_MAX`. That triggers signed overflow, which is UB in C++, and in practice can flip the sign so we misclassify a definitely-negative ramp as non-negative (or vice versa). That's a correctness breaker for the legalization pass.
Please short-circuit before the arithmetic (e.g. if either operand is ConstIntBoundNode::kNegInf/kPosInf, keep the sentinel) or compute the ramp endpoints via the analyzer instead of manual int64 math so we never overflow.
```diff
- int64_t lower = base_min;
- if (s_min < 0)
-   lower += s_min * (lanes - 1);
- int64_t upper = base_max;
- if (s_max > 0)
-   upper += s_max * (lanes - 1);
+ int64_t lower = base_min;
+ int64_t upper = base_max;
+ int64_t lane_span = static_cast<int64_t>(lanes - 1);
+ if (s_min < 0) {
+   if (lower == arith::ConstIntBoundNode::kNegInf ||
+       lower == arith::ConstIntBoundNode::kPosInf ||
+       s_min == arith::ConstIntBoundNode::kNegInf ||
+       s_min == arith::ConstIntBoundNode::kPosInf) {
+     lower = arith::ConstIntBoundNode::kNegInf;
+   } else {
+     __int128 tmp = static_cast<__int128>(lower) +
+                    static_cast<__int128>(s_min) * lane_span;
+     if (tmp <= std::numeric_limits<int64_t>::min()) {
+       lower = arith::ConstIntBoundNode::kNegInf;
+     } else if (tmp >= std::numeric_limits<int64_t>::max()) {
+       lower = arith::ConstIntBoundNode::kPosInf;
+     } else {
+       lower = static_cast<int64_t>(tmp);
+     }
+   }
+ }
+ if (s_max > 0) {
+   if (upper == arith::ConstIntBoundNode::kPosInf ||
+       upper == arith::ConstIntBoundNode::kNegInf ||
+       s_max == arith::ConstIntBoundNode::kPosInf ||
+       s_max == arith::ConstIntBoundNode::kNegInf) {
+     upper = arith::ConstIntBoundNode::kPosInf;
+   } else {
+     __int128 tmp = static_cast<__int128>(upper) +
+                    static_cast<__int128>(s_max) * lane_span;
+     if (tmp >= std::numeric_limits<int64_t>::max()) {
+       upper = arith::ConstIntBoundNode::kPosInf;
+     } else if (tmp <= std::numeric_limits<int64_t>::min()) {
+       upper = arith::ConstIntBoundNode::kNegInf;
+     } else {
+       upper = static_cast<int64_t>(tmp);
+     }
+   }
+ }
```

🤖 Prompt for AI Agents
In src/transform/legalize_negative_index.cc around lines 73 to 85, the code adds
base_min and s_min*(lanes-1) (and analogously for upper) without handling
sentinel +/-INF from const_int_bound, which causes signed overflow and UB when
base_bound or stride_bound is kNegInf/kPosInf; modify the logic to short-circuit
whenever base_bound->min_value or stride_bound->min_value (and the corresponding
max values) equal ConstIntBoundNode::kNegInf or kPosInf so that you preserve the
sentinel instead of performing the multiplication/addition, or else compute the
ramp endpoints using the analyzer APIs that return ConstIntBoundNodes rather
than doing raw int64 math, ensuring no arithmetic is performed on
INT64_MIN/INT64_MAX sentinels.
…cast node (tile-ai#1207)

* [Enhancement] Improve handling of negative indices in legalize_negative_index pass
* Added logic to handle scalar and vector indices separately, enhancing the ability to determine non-negativity and negativity of indices.
* Introduced detailed logging for cases where non-negativity cannot be proven, improving debugging capabilities.
* Refactored index state determination for vector types, including support for Ramp and Broadcast nodes.
* Fix incorrect lane handling in legalize_negative_index pass by dereferencing lanes to obtain the correct integer value.
* Enhance legalize_negative_index pass by including necessary header for TIR operations. This addition supports improved functionality and maintainability of the transformation logic.