Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 32 additions & 21 deletions src/transform/legalize_safe_memory_access.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,34 +24,34 @@ namespace tl {
using namespace tir;
using arith::IRMutatorWithAnalyzer;

// GlobalMemChecker for a BufferLoad/BufferStore node:
// SafeMemChecker for a BufferLoad/BufferStore node:
// 1. Identify BufferLoad and BufferStore nodes.
// 2. Check if the buffer is in global scope.
// 3. For each index, compare against the buffer's shape.
// 2. For each index, compare against the buffer's shape.
// If the index might exceed the shape (upper bound too large),
// log a warning or handle accordingly.
struct GlobalMemChecker : public StmtExprVisitor {
// log a warning (local/shared) or handle accordingly (global).
struct SafeMemChecker : public StmtExprVisitor {

GlobalMemChecker(arith::Analyzer *analyzer, bool recursively_collect_conds)
SafeMemChecker(arith::Analyzer *analyzer, bool recursively_collect_conds)
: analyzer_(analyzer),
recursively_collect_conds_(recursively_collect_conds) {}
void VisitExpr_(const BufferLoadNode *op) final {
// Check if the buffer is in global scope
// This is because we are writing TilePrograms, where out of bounds
// accesses only happen in the global buffer.
if (IsGlobalBuffer(op->buffer)) {
CheckBufferIndices(op->buffer, op->indices, /*is_load=*/true);
}
// If the buffer is in global scope, we will check its indices and add
// corresponding bound checks.
// If the buffer is in shared/local, although out of bound accesses are
// still possible, we assume the developers can handle them. This is because
// we are writing TilePrograms. Therefore we only log warnings if there
// are possible out-of-bounds.
CheckBufferIndices(op->buffer, op->indices, /*is_load=*/true,
!IsGlobalBuffer(op->buffer));
if (recursively_collect_conds_) {
StmtExprVisitor::VisitExpr_(op);
}
}

void VisitStmt_(const BufferStoreNode *op) final {
// Check if the buffer is in global scope
if (IsGlobalBuffer(op->buffer)) {
CheckBufferIndices(op->buffer, op->indices, /*is_load=*/false);
}
CheckBufferIndices(op->buffer, op->indices, /*is_load=*/false,
!IsGlobalBuffer(op->buffer));
if (recursively_collect_conds_) {
StmtExprVisitor::VisitStmt_(op);
}
Expand All @@ -70,7 +70,7 @@ struct GlobalMemChecker : public StmtExprVisitor {

// Check each index against the buffer shape dimensions
void CheckBufferIndices(const Buffer &buffer, const Array<PrimExpr> &indices,
bool is_load) {
bool is_load, bool throw_warning) {
// Ensure indices count matches buffer dimension
if (indices.size() != buffer->shape.size()) {
LOG(WARNING) << "Buffer access dimension mismatch: indices size ("
Expand Down Expand Up @@ -103,13 +103,24 @@ struct GlobalMemChecker : public StmtExprVisitor {
PrimExpr upper_bound_cond = index < shape_dim;
if (!analyzer_->CanProve(upper_bound_cond,
arith::ProofStrength::kSymbolicBound)) {
_conditions.push_back(upper_bound_cond);
if (throw_warning) {
LOG(WARNING) << "Index access may exceed buffer bounds: " << index
<< " >= " << shape_dim
<< "; Buffer name: " << buffer->name;
} else {
_conditions.push_back(upper_bound_cond);
}
}
// Check if index >= 0 can be proven.
PrimExpr lower_bound_cond = index >= 0;
if (!analyzer_->CanProve(lower_bound_cond,
arith::ProofStrength::kSymbolicBound)) {
_conditions.push_back(lower_bound_cond);
if (throw_warning) {
LOG(WARNING) << "Index access may be negative: " << index << " < 0"
<< "; Buffer name: " << buffer->name;
} else {
_conditions.push_back(lower_bound_cond);
}
}
}
}
Expand Down Expand Up @@ -150,7 +161,7 @@ class SafeMemorysRewriter : public IRMutatorWithAnalyzer {

// For Load/Store, we only check the current node, not its children.
// Since rewriter will recursively visit children.
GlobalMemChecker checker(analyzer_, /*recursively_collect_conds=*/false);
SafeMemChecker checker(analyzer_, /*recursively_collect_conds=*/false);
checker(load);
Array<PrimExpr> conditions = checker.GetConditions();

Expand All @@ -173,7 +184,7 @@ class SafeMemorysRewriter : public IRMutatorWithAnalyzer {
// Check if the buffer is in global scope
auto store = Downcast<BufferStore>(IRMutatorWithAnalyzer::VisitStmt_(op));

GlobalMemChecker checker(analyzer_, /*recursively_collect_conds=*/false);
SafeMemChecker checker(analyzer_, /*recursively_collect_conds=*/false);
checker(store);
Array<PrimExpr> conditions = checker.GetConditions();

Expand Down Expand Up @@ -234,7 +245,7 @@ class SafeMemorysRewriter : public IRMutatorWithAnalyzer {
// For CallExtern and atomic ops, we recursively collect conditions
// from all children. Since we cannot rewrite any BufferLoad in its
// children (Rewrite will cause potential Nullptr exception).
GlobalMemChecker checker(analyzer_, /*recursively_collect_conds=*/true);
SafeMemChecker checker(analyzer_, /*recursively_collect_conds=*/true);
checker(call);
Array<PrimExpr> conditions = checker.GetConditions();

Expand Down
Loading