diff --git a/src/transform/legalize_safe_memory_access.cc b/src/transform/legalize_safe_memory_access.cc index 2838c5b8b..4e49c7aab 100644 --- a/src/transform/legalize_safe_memory_access.cc +++ b/src/transform/legalize_safe_memory_access.cc @@ -24,24 +24,25 @@ namespace tl { using namespace tir; using arith::IRMutatorWithAnalyzer; -// GlobalMemChecker for a BufferLoad/BufferStore node: +// SafeMemChecker for a BufferLoad/BufferStore node: // 1. Identify BufferLoad and BufferStore nodes. -// 2. Check if the buffer is in global scope. -// 3. For each index, compare against the buffer's shape. +// 2. For each index, compare against the buffer's shape. // If the index might exceed the shape (upper bound too large), -// log a warning or handle accordingly. -struct GlobalMemChecker : public StmtExprVisitor { +// log a warning (local/shared) or handle accordingly (global). +struct SafeMemChecker : public StmtExprVisitor { - GlobalMemChecker(arith::Analyzer *analyzer, bool recursively_collect_conds) + SafeMemChecker(arith::Analyzer *analyzer, bool recursively_collect_conds) : analyzer_(analyzer), recursively_collect_conds_(recursively_collect_conds) {} void VisitExpr_(const BufferLoadNode *op) final { - // Check if the buffer is in global scope - // This is because we are writing TilePrograms, where out of bounds - // accesses only happen in the global buffer. - if (IsGlobalBuffer(op->buffer)) { - CheckBufferIndices(op->buffer, op->indices, /*is_load=*/true); - } + // If the buffer is in global scope, we will check its indices and add + // corresponding bound checks. + // If the buffer is in shared/local, although out of bound accesses are + // still possible, we assume the developers can handle them. This is because + // we are writing TilePrograms. Therefore we only log warnings if there + // are possible out-of-bounds. + CheckBufferIndices(op->buffer, op->indices, /*is_load=*/true, + !IsGlobalBuffer(op->buffer)); if (recursively_collect_conds_) { StmtExprVisitor::VisitExpr_(op); } @@ -49,9 +50,8 @@ struct GlobalMemChecker : public StmtExprVisitor { void VisitStmt_(const BufferStoreNode *op) final { // Check if the buffer is in global scope - if (IsGlobalBuffer(op->buffer)) { - CheckBufferIndices(op->buffer, op->indices, /*is_load=*/false); - } + CheckBufferIndices(op->buffer, op->indices, /*is_load=*/false, + !IsGlobalBuffer(op->buffer)); if (recursively_collect_conds_) { StmtExprVisitor::VisitStmt_(op); } @@ -70,7 +70,7 @@ struct GlobalMemChecker : public StmtExprVisitor { // Check each index against the buffer shape dimensions void CheckBufferIndices(const Buffer &buffer, const Array &indices, - bool is_load) { + bool is_load, bool throw_warning) { // Ensure indices count matches buffer dimension if (indices.size() != buffer->shape.size()) { LOG(WARNING) << "Buffer access dimension mismatch: indices size (" @@ -103,13 +103,24 @@ struct GlobalMemChecker : public StmtExprVisitor { PrimExpr upper_bound_cond = index < shape_dim; if (!analyzer_->CanProve(upper_bound_cond, arith::ProofStrength::kSymbolicBound)) { - _conditions.push_back(upper_bound_cond); + if (throw_warning) { + LOG(WARNING) << "Index access may exceed buffer bounds: " << index + << " >= " << shape_dim + << "; Buffer name: " << buffer->name; + } else { + _conditions.push_back(upper_bound_cond); + } } // Check if index >= 0 can be proven. PrimExpr lower_bound_cond = index >= 0; if (!analyzer_->CanProve(lower_bound_cond, arith::ProofStrength::kSymbolicBound)) { - _conditions.push_back(lower_bound_cond); + if (throw_warning) { + LOG(WARNING) << "Index access may be negative: " << index << " < 0" + << "; Buffer name: " << buffer->name; + } else { + _conditions.push_back(lower_bound_cond); + } } } } @@ -150,7 +161,7 @@ class SafeMemorysRewriter : public IRMutatorWithAnalyzer { // For Load/Store, we only check the current node, not its children. // Since rewriter will recursively visit children. - GlobalMemChecker checker(analyzer_, /*recursively_collect_conds=*/false); + SafeMemChecker checker(analyzer_, /*recursively_collect_conds=*/false); checker(load); Array conditions = checker.GetConditions(); @@ -173,7 +184,7 @@ class SafeMemorysRewriter : public IRMutatorWithAnalyzer { // Check if the buffer is in global scope auto store = Downcast(IRMutatorWithAnalyzer::VisitStmt_(op)); - GlobalMemChecker checker(analyzer_, /*recursively_collect_conds=*/false); + SafeMemChecker checker(analyzer_, /*recursively_collect_conds=*/false); checker(store); Array conditions = checker.GetConditions(); @@ -234,7 +245,7 @@ class SafeMemorysRewriter : public IRMutatorWithAnalyzer { // For CallExtern and atomic ops, we recursively collect conditions // from all children. Since we cannot rewrite any BufferLoad in its // children (Rewrite will cause potential Nullptr exception). - GlobalMemChecker checker(analyzer_, /*recursively_collect_conds=*/true); + SafeMemChecker checker(analyzer_, /*recursively_collect_conds=*/true); checker(call); Array conditions = checker.GetConditions();