Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 76 additions & 3 deletions src/transform/legalize_negative_index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/logging.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

Expand Down Expand Up @@ -37,12 +38,84 @@ class NegativeIndexAnalyzer : public IRVisitorWithAnalyzer {

for (size_t i = 0; i < op->indices.size(); ++i) {
PrimExpr simplified = analyzer_.Simplify(op->indices[i]);
if (analyzer_.CanProve(simplified >= 0)) {
states.push_back(IndexSignState::kNonNegative);

// Handle scalar indices with the standard analyzer
if (simplified.dtype().lanes() == 1) {
if (analyzer_.CanProve(simplified >= 0)) {
states.push_back(IndexSignState::kNonNegative);
continue;
}
if (analyzer_.CanProve(simplified < 0)) {
states.push_back(IndexSignState::kNegative);
needs_record = true;
continue;
}
states.push_back(IndexSignState::kUnknown);
needs_record = true;
LOG(WARNING)
<< "LegalizeNegativeIndex: cannot prove non-negative index "
<< simplified << " for buffer " << load->buffer->name << " (axis "
<< i << ").";
continue;
}

if (analyzer_.CanProve(simplified < 0)) {
// Vector indices: try to reason about non-negativity/negativity
// Common patterns are Ramp(base, stride, lanes) and Broadcast(value,
// lanes).
IndexSignState vec_state = IndexSignState::kUnknown;
if (const auto *ramp = simplified.as<RampNode>()) {
// Compute a safe lower/upper bound for the vector lanes
// lower_bound = base_min + min(0, stride_min) * (lanes - 1)
// upper_bound = base_max + max(0, stride_max) * (lanes - 1)
auto base_bound = analyzer_.const_int_bound(ramp->base);
auto stride_bound = analyzer_.const_int_bound(ramp->stride);
int lanes = *as_const_int(ramp->lanes);

int64_t base_min = base_bound->min_value;
int64_t base_max = base_bound->max_value;
int64_t s_min = stride_bound->min_value;
int64_t s_max = stride_bound->max_value;

// Guard against overflow is not strictly necessary here because
// bounds may be +/-inf represented by sentinel values.
int64_t lower = base_min;
if (s_min < 0)
lower += s_min * (lanes - 1);
int64_t upper = base_max;
if (s_max > 0)
upper += s_max * (lanes - 1);
Comment on lines +74 to +86
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Guard against overflow when combining ramp bounds

Here lower += s_min * (lanes - 1) and the analogous upper update operate directly on the results of const_int_bound. When either base_bound or stride_bound returns the sentinel ±∞ (the common case when the analyzer can’t tighten a bound), this arithmetic runs on ConstIntBoundNode::kPosInf (INT64_MAX) or kNegInf (-INT64_MAX). Multiplying such a sentinel by (lanes - 1), or adding it to another bound, triggers signed overflow, which is UB in C++, and in practice the wrapped result can flip sign so we misclassify a definitely-negative ramp as non-negative (or vice versa). That’s a correctness breaker for the legalization pass.

Please short-circuit before the arithmetic (e.g. if either operand is ConstIntBoundNode::kNegInf/kPosInf, keep the sentinel) or compute the ramp endpoints via the analyzer instead of manual int64 math so we never overflow.

-        int64_t lower = base_min;
-        if (s_min < 0)
-          lower += s_min * (lanes - 1);
-        int64_t upper = base_max;
-        if (s_max > 0)
-          upper += s_max * (lanes - 1);
+        int64_t lower = base_min;
+        int64_t upper = base_max;
+        int64_t lane_span = static_cast<int64_t>(lanes - 1);
+        if (s_min < 0) {
+          if (lower == arith::ConstIntBoundNode::kNegInf ||
+              lower == arith::ConstIntBoundNode::kPosInf ||
+              s_min == arith::ConstIntBoundNode::kNegInf ||
+              s_min == arith::ConstIntBoundNode::kPosInf) {
+            lower = arith::ConstIntBoundNode::kNegInf;
+          } else {
+            __int128 tmp = static_cast<__int128>(lower) +
+                           static_cast<__int128>(s_min) * lane_span;
+            if (tmp <= std::numeric_limits<int64_t>::min()) {
+              lower = arith::ConstIntBoundNode::kNegInf;
+            } else if (tmp >= std::numeric_limits<int64_t>::max()) {
+              lower = arith::ConstIntBoundNode::kPosInf;
+            } else {
+              lower = static_cast<int64_t>(tmp);
+            }
+          }
+        }
+        if (s_max > 0) {
+          if (upper == arith::ConstIntBoundNode::kPosInf ||
+              upper == arith::ConstIntBoundNode::kNegInf ||
+              s_max == arith::ConstIntBoundNode::kPosInf ||
+              s_max == arith::ConstIntBoundNode::kNegInf) {
+            upper = arith::ConstIntBoundNode::kPosInf;
+          } else {
+            __int128 tmp = static_cast<__int128>(upper) +
+                           static_cast<__int128>(s_max) * lane_span;
+            if (tmp >= std::numeric_limits<int64_t>::max()) {
+              upper = arith::ConstIntBoundNode::kPosInf;
+            } else if (tmp <= std::numeric_limits<int64_t>::min()) {
+              upper = arith::ConstIntBoundNode::kNegInf;
+            } else {
+              upper = static_cast<int64_t>(tmp);
+            }
+          }
+        }
🤖 Prompt for AI Agents
In src/transform/legalize_negative_index.cc around lines 73 to 85, the code adds
s_min * (lanes - 1) to base_min (and analogously s_max * (lanes - 1) to
base_max) without handling the +/-inf sentinels returned by const_int_bound.
This causes signed overflow, which is UB, whenever base_bound or stride_bound
is ConstIntBoundNode::kNegInf or kPosInf. Fix it in one of two ways:
(1) short-circuit whenever base_bound->min_value or stride_bound->min_value
(and the corresponding max values) equals kNegInf or kPosInf, preserving the
sentinel instead of performing the multiplication/addition; or
(2) compute the ramp endpoints using the analyzer APIs that return
ConstIntBoundNodes rather than doing raw int64 math.
Either way, ensure no arithmetic is performed on the sentinel values.


if (lower >= 0) {
vec_state = IndexSignState::kNonNegative;
} else if (upper < 0) {
vec_state = IndexSignState::kNegative;
} else {
vec_state = IndexSignState::kUnknown;
}
} else if (const auto *bc = simplified.as<BroadcastNode>()) {
auto v = analyzer_.Simplify(bc->value);
if (analyzer_.CanProve(v >= 0)) {
vec_state = IndexSignState::kNonNegative;
} else if (analyzer_.CanProve(v < 0)) {
vec_state = IndexSignState::kNegative;
} else {
// Try const bound if proof unavailable
auto vb = analyzer_.const_int_bound(v);
if (vb->min_value >= 0) {
vec_state = IndexSignState::kNonNegative;
} else if (vb->max_value < 0) {
vec_state = IndexSignState::kNegative;
} else {
vec_state = IndexSignState::kUnknown;
}
}
}

if (vec_state == IndexSignState::kNonNegative) {
states.push_back(IndexSignState::kNonNegative);
continue;
}
if (vec_state == IndexSignState::kNegative) {
states.push_back(IndexSignState::kNegative);
needs_record = true;
continue;
Expand Down
Loading