Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion src/tir/analysis/buffer_access_lca_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,19 @@
#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>

#include "../../runtime/thread_storage_scope.h"
#include "../../support/arena.h"

namespace tvm {
namespace tir {

/*!
* \brief Detect the lowest common ancestor(LCA) position of Buffer access.
* \note Only consider BlockNode and ForNode to be the LCA nodes.
* \note
* - Only consider BlockNode and ForNode to be the LCA nodes.
* - In the LCA locator, we are aware of the buffer scope and CUDA hierarchy so that any buffer in
* global memory will have its buffer access LCA outside all launch sites of `blockIdx`, in order to
* prevent conflicts between buffer memory scopes and CUDA hierarchy.
*/
class LCADetector : public StmtExprVisitor {
public:
Expand All @@ -51,6 +56,8 @@ class LCADetector : public StmtExprVisitor {
detector.ancestor_scopes_.push_back(&root);

detector(func->body);
detector.UpdateWithBlockidx();

// Prepare the return
Map<Buffer, Optional<Stmt>> buffer_lca;
for (const auto& kv : detector.buffer_lca_) {
Expand Down Expand Up @@ -82,6 +89,15 @@ class LCADetector : public StmtExprVisitor {
int n = ancestor_scopes_.size();
const ScopeInfo* parent_scope = ancestor_scopes_.back();
auto* current_scope = arena_.make<ScopeInfo>(parent_scope, op, n);

if (op->thread_binding.defined()) {
const runtime::ThreadScope& scope =
runtime::ThreadScope::Create(op->thread_binding.value()->thread_tag);
if (scope.rank == 0) {
blockidx_scopes_.push_back(current_scope);
}
}

ancestor_scopes_.push_back(current_scope);
StmtExprVisitor::VisitStmt_(op);
ancestor_scopes_.pop_back();
Expand All @@ -107,6 +123,18 @@ class LCADetector : public StmtExprVisitor {
ancestor_scopes_.pop_back();
}

void VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::thread_extent) {
const auto* iter = op->node.as<IterVarNode>();
ICHECK_NOTNULL(iter);
const runtime::ThreadScope& scope = runtime::ThreadScope::Create(iter->thread_tag);
if (scope.rank == 0) {
blockidx_scopes_.push_back(ancestor_scopes_.back());
}
}
StmtExprVisitor::VisitStmt_(op);
}

void VisitExpr_(const BufferLoadNode* op) final {
UpdateBufferLCA(op->buffer.get());
StmtExprVisitor::VisitExpr_(op);
Expand Down Expand Up @@ -150,6 +178,19 @@ class LCADetector : public StmtExprVisitor {
}
}

void UpdateWithBlockidx() {
for (const auto& it : buffer_lca_) {
const runtime::StorageScope& scope =
runtime::StorageScope::Create(GetRef<Buffer>(it.first).scope());
if (scope.rank == runtime::StorageRank::kGlobal) {
const ScopeInfo*& lca = buffer_lca_[it.first];
for (const ScopeInfo* blockidx_scope : blockidx_scopes_) {
lca = LowestCommonAncestor(lca, blockidx_scope);
}
}
}
}

static const ScopeInfo* LowestCommonAncestor(const ScopeInfo* lhs, const ScopeInfo* rhs) {
if (lhs == nullptr) return rhs;
if (rhs == nullptr) return lhs;
Expand Down Expand Up @@ -186,6 +227,8 @@ class LCADetector : public StmtExprVisitor {
std::unordered_map<const VarNode*, const BufferNode*> buffer_var_map_ = {};
/*! \brief The match buffers inside blocks. */
std::unordered_set<const BufferNode*> match_buffers_ = {};
/*! \brief The ForNodes/BlockNodes which contain immediate `blockIdx` launch. */
std::vector<const ScopeInfo*> blockidx_scopes_ = {};
/*! \brief Internal arena. */
support::Arena arena_;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,19 @@ def match_buffer_func(a: T.handle, b: T.handle) -> None:
T.evaluate(B1.data)


@T.prim_func
def global_buffer_with_blockidx(
a: T.Buffer[(1, 32), "int32"], b: T.Buffer[(1, 32), "int32"]
) -> None:
for i0 in T.thread_binding(0, 1, thread="blockIdx.x"):
for i1 in T.thread_binding(0, 32, thread="threadIdx.x"):
with T.block("copy"):
i, j = T.axis.remap("SS", [i0, i1])
T.reads(a[i, j])
T.writes(b[i, j])
b[i, j] = a[i, j]


def test_buffer_load_store():
func = buffer_load_store_func
A, B = [func.buffer_map[x] for x in func.params]
Expand Down Expand Up @@ -154,8 +167,21 @@ def test_match_buffer():
assert lca[B] == block


def test_global_buffer_with_blockidx():
func = global_buffer_with_blockidx
A, B = [func.buffer_map[x] for x in func.params]
lca = tir.analysis.detect_buffer_access_lca(func)

root_block = func.body.block
blockidx_loop = root_block.body
# LCA of both A and B should be the loop bound to `blockIdx`
assert lca[A] == blockidx_loop
assert lca[B] == blockidx_loop


if __name__ == "__main__":
test_buffer_load_store()
test_opaque_access()
test_lca_func_root()
test_match_buffer()
test_global_buffer_with_blockidx()