diff --git a/src/transform/legalize_negative_index.cc b/src/transform/legalize_negative_index.cc index f0df555ef..5a84ffac6 100644 --- a/src/transform/legalize_negative_index.cc +++ b/src/transform/legalize_negative_index.cc @@ -44,52 +44,28 @@ class NegativeIndexAnalyzer : public IRVisitorWithAnalyzer { PrimExpr simplified = analyzer_.Simplify(indices[i]); IndexSignState state = IndexSignState::kUnknown; - // Handle scalar indices with the standard analyzer - if (simplified.dtype().lanes() == 1) { - if (analyzer_.CanProve(simplified >= 0)) + // Handle vector patterns first to avoid querying lanes() on + // scalable vectors (which is not allowed at compile-time). + if (const auto *ramp = simplified.as()) { + // For scalable vectors, we cannot rely on a constant lane count. + // Use sufficient (but not necessary) conditions: + // - If base >= 0 and stride >= 0, all lanes are non-negative. + // - If base < 0 and stride <= 0, all lanes are negative. + bool base_nonneg = analyzer_.CanProve(ramp->base >= 0); + bool base_neg = analyzer_.CanProve(ramp->base < 0); + bool stride_nonneg = analyzer_.CanProve(ramp->stride >= 0); + bool stride_nonpos = analyzer_.CanProve(ramp->stride <= 0); + + if (base_nonneg && stride_nonneg) { state = IndexSignState::kNonNegative; - else if (analyzer_.CanProve(simplified < 0)) + } else if (base_neg && stride_nonpos) { state = IndexSignState::kNegative; - else - DLOG(WARNING) - << "LegalizeNegativeIndex: cannot prove non-negative index " - << simplified << " for buffer " << buffer_name << " (axis " << i - << ", index " + indices[i]->Script() + ")."; - } - // Vector indices: try to reason about non-negativity/negativity - // Common patterns are Ramp(base, stride, lanes) and Broadcast(value, - // lanes). - else if (const auto *ramp = simplified.as()) { - // Compute a safe lower/upper bound for the vector lanes - // lower_bound = base_min + min(0, stride_min) * (lanes - 1) - // upper_bound = base_max + max(0, stride_max) * (lanes - 1) - auto base_bound = analyzer_.const_int_bound(ramp->base); - auto stride_bound = analyzer_.const_int_bound(ramp->stride); - int lanes = *as_const_int(ramp->lanes); - - int64_t base_min = base_bound->min_value; - int64_t base_max = base_bound->max_value; - int64_t s_min = stride_bound->min_value; - int64_t s_max = stride_bound->max_value; - - // Guard against overflow is not strictly necessary here because - // bounds may be +/-inf represented by sentinel values. - int64_t lower = base_min; - if (s_min < 0) - lower += s_min * (lanes - 1); - int64_t upper = base_max; - if (s_max > 0) - upper += s_max * (lanes - 1); - - if (lower >= 0) - state = IndexSignState::kNonNegative; - else if (upper < 0) - state = IndexSignState::kNegative; - else + } else { DLOG(WARNING) << "LegalizeNegativeIndex: cannot prove non-negative index " << simplified << " for buffer " << buffer_name << " (axis " << i << ", index " + indices[i]->Script() + ")."; + } } else if (const auto *broadcast = simplified.as()) { auto v = analyzer_.Simplify(broadcast->value); if (analyzer_.CanProve(v >= 0)) @@ -109,6 +85,20 @@ class NegativeIndexAnalyzer : public IRVisitorWithAnalyzer { << simplified << " for buffer " << buffer_name << " (axis " << i << ", index " + indices[i]->Script() + ")."; } + } else { + // Assume scalar (or non-Ramp/Broadcast) index; avoid querying lanes(). + // Fall back to scalar reasoning. If this expression is actually a + // vector-but-not-Ramp/Broadcast, treat as unknown to be safe. + // Try to prove scalar first; if proof fails, leave as unknown. + if (analyzer_.CanProve(simplified >= 0)) + state = IndexSignState::kNonNegative; + else if (analyzer_.CanProve(simplified < 0)) + state = IndexSignState::kNegative; + else + DLOG(WARNING) + << "LegalizeNegativeIndex: cannot prove non-negative index " + << simplified << " for buffer " << buffer_name << " (axis " << i + << ", index " + indices[i]->Script() + ")."; } states.push_back(state); }