Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/support/ffi_aliases.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <tvm/ffi/cast.h>
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/memory.h>
#include <tvm/ffi/optional.h>
#include <tvm/ffi/string.h>
Expand Down
214 changes: 110 additions & 104 deletions src/transform/legalize_negative_index.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*!
* \file legalize_negative_index.cc
* \brief Legalize negative indices in buffer load expressions.
* \brief Legalize negative indices in buffer load/store expressions.
*/

#include <tvm/ffi/reflection/registry.h>
Expand All @@ -10,6 +10,7 @@
#include <tvm/tir/transform.h>

#include <unordered_map>
#include <variant>
#include <vector>

#include "arith/ir_mutator_with_analyzer.h"
Expand All @@ -23,47 +24,42 @@ using arith::IRVisitorWithAnalyzer;

/*!
 * \brief Sign classification of a single buffer index expression.
 *
 * - kNonNegative: the analyzer proved the index is >= 0.
 * - kNegative:    the analyzer proved the index is < 0; these are the only
 *                 indices the rewriter adjusts (by adding the buffer extent).
 * - kUnknown:     neither sign could be proven; the index is left as-is.
 */
enum class IndexSignState { kNonNegative, kNegative, kUnknown };

// Identifies one buffer access by the address of its IR node, which is either
// a BufferLoad or a BufferStore.
using BufferAccessVariant =
std::variant<const BufferLoadNode *, const BufferStoreNode *>;
// Maps a buffer access to the sign state of each of its indices (one entry
// per index, in index order).
using LoadStore2StateMap =
std::unordered_map<BufferAccessVariant, std::vector<IndexSignState>>;

class NegativeIndexAnalyzer : public IRVisitorWithAnalyzer {
public:
explicit NegativeIndexAnalyzer(
std::unordered_map<const BufferLoadNode *, std::vector<IndexSignState>>
*result)
/*!
 * \brief Create an analyzer that records index sign states into \p result.
 * \param result Output map filled while visiting loads/stores. The pointer is
 *        stored as-is and dereferenced without a null check when an access
 *        is recorded.
 */
explicit NegativeIndexAnalyzer(LoadStore2StateMap *result)
: result_(result) {}

void VisitExpr_(const BufferLoadNode *op) final {
auto load = tvm::ffi::GetRef<BufferLoad>(op);
private:
std::vector<IndexSignState> ProcessIdx(const ffi::Array<PrimExpr> &indices,
ffi::String buffer_name) {
std::vector<IndexSignState> states;
states.reserve(op->indices.size());
bool needs_record = false;
states.reserve(indices.size());

for (size_t i = 0; i < op->indices.size(); ++i) {
PrimExpr simplified = analyzer_.Simplify(op->indices[i]);
for (size_t i = 0; i < indices.size(); ++i) {
PrimExpr simplified = analyzer_.Simplify(indices[i]);
IndexSignState state = IndexSignState::kUnknown;

// Handle scalar indices with the standard analyzer
if (simplified.dtype().lanes() == 1) {
if (analyzer_.CanProve(simplified >= 0)) {
states.push_back(IndexSignState::kNonNegative);
continue;
}
if (analyzer_.CanProve(simplified < 0)) {
states.push_back(IndexSignState::kNegative);
needs_record = true;
continue;
}
states.push_back(IndexSignState::kUnknown);
needs_record = true;
DLOG(WARNING)
<< "LegalizeNegativeIndex: cannot prove non-negative index "
<< simplified << " for buffer " << load->buffer->name << " (axis "
<< i << ").";
continue;
if (analyzer_.CanProve(simplified >= 0))
state = IndexSignState::kNonNegative;
else if (analyzer_.CanProve(simplified < 0))
state = IndexSignState::kNegative;
else
DLOG(WARNING)
<< "LegalizeNegativeIndex: cannot prove non-negative index "
<< simplified << " for buffer " << buffer_name << " (axis " << i
<< ", index " + indices[i]->Script() + ").";
}

// Vector indices: try to reason about non-negativity/negativity
// Common patterns are Ramp(base, stride, lanes) and Broadcast(value,
// lanes).
IndexSignState vec_state = IndexSignState::kUnknown;
if (const auto *ramp = simplified.as<RampNode>()) {
else if (const auto *ramp = simplified.as<RampNode>()) {
// Compute a safe lower/upper bound for the vector lanes
// lower_bound = base_min + min(0, stride_min) * (lanes - 1)
// upper_bound = base_max + max(0, stride_max) * (lanes - 1)
Expand All @@ -85,127 +81,137 @@ class NegativeIndexAnalyzer : public IRVisitorWithAnalyzer {
if (s_max > 0)
upper += s_max * (lanes - 1);

if (lower >= 0) {
vec_state = IndexSignState::kNonNegative;
} else if (upper < 0) {
vec_state = IndexSignState::kNegative;
} else {
vec_state = IndexSignState::kUnknown;
}
} else if (const auto *bc = simplified.as<BroadcastNode>()) {
auto v = analyzer_.Simplify(bc->value);
if (analyzer_.CanProve(v >= 0)) {
vec_state = IndexSignState::kNonNegative;
} else if (analyzer_.CanProve(v < 0)) {
vec_state = IndexSignState::kNegative;
} else {
if (lower >= 0)
state = IndexSignState::kNonNegative;
else if (upper < 0)
state = IndexSignState::kNegative;
else
DLOG(WARNING)
<< "LegalizeNegativeIndex: cannot prove non-negative index "
<< simplified << " for buffer " << buffer_name << " (axis " << i
<< ", index " + indices[i]->Script() + ").";
} else if (const auto *broadcast = simplified.as<BroadcastNode>()) {
auto v = analyzer_.Simplify(broadcast->value);
if (analyzer_.CanProve(v >= 0))
state = IndexSignState::kNonNegative;
else if (analyzer_.CanProve(v < 0))
state = IndexSignState::kNegative;
else {
// Try const bound if proof unavailable
auto vb = analyzer_.const_int_bound(v);
if (vb->min_value >= 0) {
vec_state = IndexSignState::kNonNegative;
} else if (vb->max_value < 0) {
vec_state = IndexSignState::kNegative;
} else {
vec_state = IndexSignState::kUnknown;
}
if (vb->min_value >= 0)
state = IndexSignState::kNonNegative;
else if (vb->max_value < 0)
state = IndexSignState::kNegative;
else
DLOG(WARNING)
<< "LegalizeNegativeIndex: cannot prove non-negative index "
<< simplified << " for buffer " << buffer_name << " (axis " << i
<< ", index " + indices[i]->Script() + ").";
}
}
states.push_back(state);
}

if (vec_state == IndexSignState::kNonNegative) {
states.push_back(IndexSignState::kNonNegative);
continue;
}
if (vec_state == IndexSignState::kNegative) {
states.push_back(IndexSignState::kNegative);
needs_record = true;
continue;
}
return std::move(states);
}

states.push_back(IndexSignState::kUnknown);
needs_record = true;
DLOG(WARNING) << "LegalizeNegativeIndex: cannot prove non-negative index "
<< simplified << " for buffer " << load->buffer->name
<< " (axis " << i << ").";
}
bool NeedRecord(const std::vector<IndexSignState> &states) {
return std::any_of(states.begin(), states.end(),
[](const IndexSignState &state) {
return state == IndexSignState::kUnknown ||
state == IndexSignState::kNegative;
});
}

void VisitExpr_(const BufferLoadNode *op) final {
std::vector<IndexSignState> states =
ProcessIdx(op->indices, op->buffer->name);

if (needs_record) {
if (NeedRecord(states))
(*result_)[op] = std::move(states);
}

IRVisitorWithAnalyzer::VisitExpr_(op);
}

/*!
 * \brief Classify the indices of a buffer store, then continue the traversal.
 *
 * The index states are computed before the base visitor is invoked on the
 * node. They are stored in the result map (keyed by the node pointer) only
 * when NeedRecord reports that at least one index is negative or unknown.
 */
void VisitStmt_(const BufferStoreNode *op) final {
  std::vector<IndexSignState> store_states =
      ProcessIdx(op->indices, op->buffer->name);

  const bool worth_recording = NeedRecord(store_states);
  if (worth_recording) {
    (*result_)[op] = std::move(store_states);
  }

  IRVisitorWithAnalyzer::VisitStmt_(op);
}

private:
std::unordered_map<const BufferLoadNode *, std::vector<IndexSignState>>
*result_;
LoadStore2StateMap *result_;
};

class NegativeIndexRewriter : public arith::IRMutatorWithAnalyzer {
public:
static PrimFunc
Apply(PrimFunc func,
const std::unordered_map<const BufferLoadNode *,
std::vector<IndexSignState>> &states) {
/*!
 * \brief Rewrite the body of \p func using previously recorded sign states.
 * \param func The function to rewrite; returned unchanged if it has no body.
 * \param states Sign states per buffer access, keyed by IR node pointer.
 * \return The function with its body replaced by the rewritten statement.
 */
static PrimFunc Apply(PrimFunc func, const LoadStore2StateMap &states) {
  const bool has_body = func->body.defined();
  if (has_body) {
    arith::Analyzer analyzer;
    NegativeIndexRewriter mutator(&analyzer, states);
    // Obtain a writable node before replacing the body.
    PrimFuncNode *writable = func.CopyOnWrite();
    writable->body = mutator.VisitStmt(writable->body);
  }
  return func;
}

private:
NegativeIndexRewriter(
arith::Analyzer *analyzer,
const std::unordered_map<const BufferLoadNode *,
std::vector<IndexSignState>> &states)
/*!
 * \brief Construct a rewriter over the given analyzer and state map.
 * \param analyzer Forwarded to the IRMutatorWithAnalyzer base class.
 * \param states Recorded sign states; held by reference (not copied), so it
 *        presumably must outlive this rewriter.
 */
NegativeIndexRewriter(arith::Analyzer *analyzer,
const LoadStore2StateMap &states)
: arith::IRMutatorWithAnalyzer(analyzer), states_(states) {}

/*!
 * \brief Produce the index list with proven-negative indices wrapped around.
 * \param indices Indices of the (already mutated) buffer access.
 * \param buffer_shape Shape of the accessed buffer; entry i is the extent
 *        added to index i when that index is negative.
 * \param state_vec Sign state per index; must have the same length as
 *        \p indices (checked with ICHECK_EQ).
 * \return A copy of \p indices in which every index marked kNegative is
 *         replaced by Simplify(extent + index); all others are untouched.
 */
ffi::Array<PrimExpr> UpdateIdx(const ffi::Array<PrimExpr> &indices,
                               const ffi::Array<PrimExpr> &buffer_shape,
                               const std::vector<IndexSignState> &state_vec) {
  ICHECK_EQ(state_vec.size(), indices.size())
      << "State vector size mismatch for buffer load/store indices ("
      << indices << ")";

  ffi::Array<PrimExpr> rewritten = indices;
  for (size_t axis = 0; axis < indices.size(); ++axis) {
    if (state_vec[axis] == IndexSignState::kNegative) {
      PrimExpr wrapped =
          analyzer_->Simplify(buffer_shape[axis] + indices[axis]);
      rewritten.Set(axis, wrapped);
    }
  }
  return rewritten;
}

PrimExpr VisitExpr_(const BufferLoadNode *op) final {
BufferLoad load =
Downcast<BufferLoad>(arith::IRMutatorWithAnalyzer::VisitExpr_(op));

auto it = states_.find(op);
if (it == states_.end()) {
if (it == states_.end())
return load;
}

auto indices = load->indices;
bool changed = false;

const auto &state_vector = it->second;
ICHECK_EQ(state_vector.size(), indices.size())
<< "State vector size mismatch for buffer load " << load->buffer->name;
auto indices = UpdateIdx(load->indices, load->buffer->shape, it->second);
return BufferLoad(load->buffer, indices, load->predicate);
}

for (size_t i = 0; i < indices.size(); ++i) {
if (state_vector[i] != IndexSignState::kNegative) {
continue;
}
PrimExpr extent = load->buffer->shape[i];
indices.Set(i, analyzer_->Simplify(extent + indices[i]));
changed = true;
}
Stmt VisitStmt_(const BufferStoreNode *op) final {
BufferStore store =
Downcast<BufferStore>(arith::IRMutatorWithAnalyzer::VisitStmt_(op));

if (!changed) {
return load;
}
auto it = states_.find(op);
if (it == states_.end())
return store;

return BufferLoad(load->buffer, indices);
auto indices = UpdateIdx(store->indices, store->buffer->shape, it->second);
return BufferStore(store->buffer, store->value, indices, store->predicate);
}

const std::unordered_map<const BufferLoadNode *, std::vector<IndexSignState>>
&states_;
private:
const LoadStore2StateMap &states_;
};

PrimFunc LegalizeNegativeIndex(PrimFunc func) {
if (!func->body.defined()) {
return func;
}

std::unordered_map<const BufferLoadNode *, std::vector<IndexSignState>>
states;
LoadStore2StateMap states;
NegativeIndexAnalyzer analyzer(&states);
analyzer(func->body);
if (states.empty()) {
Expand Down
Loading
Loading