Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 30 additions & 40 deletions src/transform/legalize_negative_index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,52 +44,28 @@ class NegativeIndexAnalyzer : public IRVisitorWithAnalyzer {
PrimExpr simplified = analyzer_.Simplify(indices[i]);
IndexSignState state = IndexSignState::kUnknown;

// Handle scalar indices with the standard analyzer
if (simplified.dtype().lanes() == 1) {
if (analyzer_.CanProve(simplified >= 0))
// Handle vector patterns first to avoid querying lanes() on
// scalable vectors (which is not allowed at compile-time).
if (const auto *ramp = simplified.as<RampNode>()) {
// For scalable vectors, we cannot rely on a constant lane count.
// Use sufficient (but not necessary) conditions:
// - If base >= 0 and stride >= 0, all lanes are non-negative.
// - If base < 0 and stride <= 0, all lanes are negative.
bool base_nonneg = analyzer_.CanProve(ramp->base >= 0);
bool base_neg = analyzer_.CanProve(ramp->base < 0);
bool stride_nonneg = analyzer_.CanProve(ramp->stride >= 0);
bool stride_nonpos = analyzer_.CanProve(ramp->stride <= 0);

if (base_nonneg && stride_nonneg) {
state = IndexSignState::kNonNegative;
else if (analyzer_.CanProve(simplified < 0))
} else if (base_neg && stride_nonpos) {
state = IndexSignState::kNegative;
else
DLOG(WARNING)
<< "LegalizeNegativeIndex: cannot prove non-negative index "
<< simplified << " for buffer " << buffer_name << " (axis " << i
<< ", index " + indices[i]->Script() + ").";
}
// Vector indices: try to reason about non-negativity/negativity
// Common patterns are Ramp(base, stride, lanes) and Broadcast(value,
// lanes).
else if (const auto *ramp = simplified.as<RampNode>()) {
// Compute a safe lower/upper bound for the vector lanes
// lower_bound = base_min + min(0, stride_min) * (lanes - 1)
// upper_bound = base_max + max(0, stride_max) * (lanes - 1)
auto base_bound = analyzer_.const_int_bound(ramp->base);
auto stride_bound = analyzer_.const_int_bound(ramp->stride);
int lanes = *as_const_int(ramp->lanes);

int64_t base_min = base_bound->min_value;
int64_t base_max = base_bound->max_value;
int64_t s_min = stride_bound->min_value;
int64_t s_max = stride_bound->max_value;

// Guard against overflow is not strictly necessary here because
// bounds may be +/-inf represented by sentinel values.
int64_t lower = base_min;
if (s_min < 0)
lower += s_min * (lanes - 1);
int64_t upper = base_max;
if (s_max > 0)
upper += s_max * (lanes - 1);

if (lower >= 0)
state = IndexSignState::kNonNegative;
else if (upper < 0)
state = IndexSignState::kNegative;
else
} else {
DLOG(WARNING)
<< "LegalizeNegativeIndex: cannot prove non-negative index "
<< simplified << " for buffer " << buffer_name << " (axis " << i
<< ", index " + indices[i]->Script() + ").";
}
} else if (const auto *broadcast = simplified.as<BroadcastNode>()) {
auto v = analyzer_.Simplify(broadcast->value);
if (analyzer_.CanProve(v >= 0))
Expand All @@ -109,6 +85,20 @@ class NegativeIndexAnalyzer : public IRVisitorWithAnalyzer {
<< simplified << " for buffer " << buffer_name << " (axis " << i
<< ", index " + indices[i]->Script() + ").";
}
} else {
// Assume scalar (or non-Ramp/Broadcast) index; avoid querying lanes().
// Fall back to scalar reasoning. If this expression is actually a
// vector-but-not-Ramp/Broadcast, treat as unknown to be safe.
// Try to prove scalar first; if proof fails, leave as unknown.
if (analyzer_.CanProve(simplified >= 0))
state = IndexSignState::kNonNegative;
else if (analyzer_.CanProve(simplified < 0))
state = IndexSignState::kNegative;
else
DLOG(WARNING)
<< "LegalizeNegativeIndex: cannot prove non-negative index "
<< simplified << " for buffer " << buffer_name << " (axis " << i
<< ", index " + indices[i]->Script() + ").";
}
states.push_back(state);
}
Expand Down
Loading