diff --git a/src/transform/legalize_negative_index.cc b/src/transform/legalize_negative_index.cc index a1713d835..150be61bb 100644 --- a/src/transform/legalize_negative_index.cc +++ b/src/transform/legalize_negative_index.cc @@ -5,6 +5,7 @@ #include #include +#include #include #include @@ -37,12 +38,84 @@ class NegativeIndexAnalyzer : public IRVisitorWithAnalyzer { for (size_t i = 0; i < op->indices.size(); ++i) { PrimExpr simplified = analyzer_.Simplify(op->indices[i]); - if (analyzer_.CanProve(simplified >= 0)) { - states.push_back(IndexSignState::kNonNegative); + + // Handle scalar indices with the standard analyzer + if (simplified.dtype().lanes() == 1) { + if (analyzer_.CanProve(simplified >= 0)) { + states.push_back(IndexSignState::kNonNegative); + continue; + } + if (analyzer_.CanProve(simplified < 0)) { + states.push_back(IndexSignState::kNegative); + needs_record = true; + continue; + } + states.push_back(IndexSignState::kUnknown); + needs_record = true; + LOG(WARNING) + << "LegalizeNegativeIndex: cannot prove non-negative index " + << simplified << " for buffer " << load->buffer->name << " (axis " + << i << ")."; continue; } - if (analyzer_.CanProve(simplified < 0)) { + // Vector indices: try to reason about non-negativity/negativity + // Common patterns are Ramp(base, stride, lanes) and Broadcast(value, + // lanes). + IndexSignState vec_state = IndexSignState::kUnknown; + if (const auto *ramp = simplified.as()) { + // Compute a safe lower/upper bound for the vector lanes + // lower_bound = base_min + min(0, stride_min) * (lanes - 1) + // upper_bound = base_max + max(0, stride_max) * (lanes - 1) + auto base_bound = analyzer_.const_int_bound(ramp->base); + auto stride_bound = analyzer_.const_int_bound(ramp->stride); + int lanes = *as_const_int(ramp->lanes); + + int64_t base_min = base_bound->min_value; + int64_t base_max = base_bound->max_value; + int64_t s_min = stride_bound->min_value; + int64_t s_max = stride_bound->max_value; + + // Guard against overflow is not strictly necessary here because + // bounds may be +/-inf represented by sentinel values. + int64_t lower = base_min; + if (s_min < 0) + lower += s_min * (lanes - 1); + int64_t upper = base_max; + if (s_max > 0) + upper += s_max * (lanes - 1); + + if (lower >= 0) { + vec_state = IndexSignState::kNonNegative; + } else if (upper < 0) { + vec_state = IndexSignState::kNegative; + } else { + vec_state = IndexSignState::kUnknown; + } + } else if (const auto *bc = simplified.as()) { + auto v = analyzer_.Simplify(bc->value); + if (analyzer_.CanProve(v >= 0)) { + vec_state = IndexSignState::kNonNegative; + } else if (analyzer_.CanProve(v < 0)) { + vec_state = IndexSignState::kNegative; + } else { + // Try const bound if proof unavailable + auto vb = analyzer_.const_int_bound(v); + if (vb->min_value >= 0) { + vec_state = IndexSignState::kNonNegative; + } else if (vb->max_value < 0) { + vec_state = IndexSignState::kNegative; + } else { + vec_state = IndexSignState::kUnknown; + } + } + } + + if (vec_state == IndexSignState::kNonNegative) { + states.push_back(IndexSignState::kNonNegative); + continue; + } + if (vec_state == IndexSignState::kNegative) { states.push_back(IndexSignState::kNegative); needs_record = true; continue;