Commit e92f5d4

[BugFix][TIR] Fix Buffer LCA Detector (#12819)
Prior to this PR, the buffer LCA detector in TIR did not take buffer memory scopes and the GPU thread hierarchy into consideration. A consequent issue is that, when an intermediate buffer lives in global memory, TIR's lowering passes do not necessarily allocate it outside all `blockIdx` launches; the global intermediate buffer can end up allocated inside a GPU thread block, which is illegal. This PR fixes the issue by making the LCA detector aware of buffer memory scopes and the GPU hierarchy, so that global intermediate buffers are always allocated outside `blockIdx`.
1 parent 91cce56 commit e92f5d4
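
To make the issue concrete, here is a minimal TVMScript sketch; it is not taken from this commit, and the function name `global_intermediate` and the intermediate buffer `C` are invented for illustration. It shows a global-scope intermediate buffer whose reads and writes all sit under a `threadIdx.x` loop nested inside a `blockIdx.x` binding, which is the situation the description refers to.

```python
from tvm import tir
from tvm.script import tir as T


@T.prim_func
def global_intermediate(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (1, 32), "float32")
    B = T.match_buffer(b, (1, 32), "float32")
    # Intermediate buffer in global memory ("global" is the default scope).
    C = T.alloc_buffer((1, 32), "float32")
    for bx in T.thread_binding(0, 1, thread="blockIdx.x"):
        for tx in T.thread_binding(0, 32, thread="threadIdx.x"):
            with T.block("produce"):
                i, j = T.axis.remap("SS", [bx, tx])
                T.reads(A[i, j])
                T.writes(C[i, j])
                C[i, j] = A[i, j] + 1.0
            with T.block("consume"):
                i, j = T.axis.remap("SS", [bx, tx])
                T.reads(C[i, j])
                T.writes(B[i, j])
                B[i, j] = C[i, j] * 2.0


lca = tir.analysis.detect_buffer_access_lca(global_intermediate)
# Before this fix, the LCA reported for C lands at the threadIdx.x loop,
# i.e. inside the thread block; with the fix it is lifted to the
# blockIdx.x loop, so lowering can allocate C outside the block.
```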

File tree

2 files changed: +70 -1 lines changed
src/tir/analysis/buffer_access_lca_detector.cc

Lines changed: 44 additions & 1 deletion
```diff
@@ -25,14 +25,19 @@
 #include <tvm/tir/analysis.h>
 #include <tvm/tir/stmt_functor.h>
 
+#include "../../runtime/thread_storage_scope.h"
 #include "../../support/arena.h"
 
 namespace tvm {
 namespace tir {
 
 /*!
  * \brief Detect the lowest common ancestor(LCA) position of Buffer access.
- * \note Only consider BlockNode and ForNode to be the LCA nodes.
+ * \note
+ * - Only consider BlockNode and ForNode to be the LCA nodes.
+ * - In the LCA locator, we are aware of the buffer scope and CUDA hierarchy so that any buffer in
+ * global memory will have its buffer access LCA outside all launch sites of `blockIdx`, in order to
+ * prevent conflicts between buffer memory scopes and CUDA hierarchy.
  */
 class LCADetector : public StmtExprVisitor {
  public:
@@ -51,6 +56,8 @@ class LCADetector : public StmtExprVisitor {
     detector.ancestor_scopes_.push_back(&root);
 
     detector(func->body);
+    detector.UpdateWithBlockidx();
+
     // Prepare the return
     Map<Buffer, Optional<Stmt>> buffer_lca;
     for (const auto& kv : detector.buffer_lca_) {
@@ -82,6 +89,15 @@ class LCADetector : public StmtExprVisitor {
     int n = ancestor_scopes_.size();
     const ScopeInfo* parent_scope = ancestor_scopes_.back();
     auto* current_scope = arena_.make<ScopeInfo>(parent_scope, op, n);
+
+    if (op->thread_binding.defined()) {
+      const runtime::ThreadScope& scope =
+          runtime::ThreadScope::Create(op->thread_binding.value()->thread_tag);
+      if (scope.rank == 0) {
+        blockidx_scopes_.push_back(current_scope);
+      }
+    }
+
     ancestor_scopes_.push_back(current_scope);
     StmtExprVisitor::VisitStmt_(op);
     ancestor_scopes_.pop_back();
@@ -107,6 +123,18 @@ class LCADetector : public StmtExprVisitor {
     ancestor_scopes_.pop_back();
   }
 
+  void VisitStmt_(const AttrStmtNode* op) final {
+    if (op->attr_key == attr::thread_extent) {
+      const auto* iter = op->node.as<IterVarNode>();
+      ICHECK_NOTNULL(iter);
+      const runtime::ThreadScope& scope = runtime::ThreadScope::Create(iter->thread_tag);
+      if (scope.rank == 0) {
+        blockidx_scopes_.push_back(ancestor_scopes_.back());
+      }
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
   void VisitExpr_(const BufferLoadNode* op) final {
     UpdateBufferLCA(op->buffer.get());
     StmtExprVisitor::VisitExpr_(op);
@@ -150,6 +178,19 @@ class LCADetector : public StmtExprVisitor {
     }
   }
 
+  void UpdateWithBlockidx() {
+    for (const auto& it : buffer_lca_) {
+      const runtime::StorageScope& scope =
+          runtime::StorageScope::Create(GetRef<Buffer>(it.first).scope());
+      if (scope.rank == runtime::StorageRank::kGlobal) {
+        const ScopeInfo*& lca = buffer_lca_[it.first];
+        for (const ScopeInfo* blockidx_scope : blockidx_scopes_) {
+          lca = LowestCommonAncestor(lca, blockidx_scope);
+        }
+      }
+    }
+  }
+
   static const ScopeInfo* LowestCommonAncestor(const ScopeInfo* lhs, const ScopeInfo* rhs) {
     if (lhs == nullptr) return rhs;
     if (rhs == nullptr) return lhs;
@@ -186,6 +227,8 @@ class LCADetector : public StmtExprVisitor {
   std::unordered_map<const VarNode*, const BufferNode*> buffer_var_map_ = {};
   /*! \brief The match buffers inside blocks. */
   std::unordered_set<const BufferNode*> match_buffers_ = {};
+  /*! \brief The ForNodes/BlockNodes which contain immediate `blockIdx` launch. */
+  std::vector<const ScopeInfo*> blockidx_scopes_ = {};
   /*! \brief Internal arena. */
   support::Arena arena_;
 };
```
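
For readers unfamiliar with how the detector merges scopes, here is a plain-Python sketch (illustrative only, not TVM code) of the two ingredients added above: the depth-based lowest-common-ancestor walk and the final pass that lifts every global buffer's LCA above all recorded `blockIdx` launch scopes. The names mirror the C++, everything else is an assumption for illustration.

```python
# Minimal stand-in for the ScopeInfo tree node used by the detector.
class ScopeInfo:
    def __init__(self, parent=None):
        self.parent = parent  # enclosing For/Block scope, None at the root
        self.depth = 0 if parent is None else parent.depth + 1


def lowest_common_ancestor(lhs, rhs):
    # Walk the deeper node up to the other's depth, then climb both together.
    if lhs is None:
        return rhs
    if rhs is None:
        return lhs
    while lhs.depth > rhs.depth:
        lhs = lhs.parent
    while rhs.depth > lhs.depth:
        rhs = rhs.parent
    while lhs is not rhs:
        lhs, rhs = lhs.parent, rhs.parent
    return lhs


def update_with_blockidx(buffer_lca, blockidx_scopes, storage_scope_of):
    # For every buffer whose storage scope is "global", merge its current LCA
    # with each scope that launches blockIdx, so the final LCA sits no deeper
    # than any blockIdx launch site.
    for buf in list(buffer_lca):
        if storage_scope_of(buf) == "global":
            lca = buffer_lca[buf]
            for scope in blockidx_scopes:
                lca = lowest_common_ancestor(lca, scope)
            buffer_lca[buf] = lca
```

Note that the C++ change records a `blockIdx` launch site both for `For` loops carrying a `blockIdx` thread binding and for `AttrStmt` nodes with a rank-0 `thread_extent` attribute, so both launch styles are covered.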

tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py

Lines changed: 26 additions & 0 deletions
```diff
@@ -93,6 +93,19 @@ def match_buffer_func(a: T.handle, b: T.handle) -> None:
             T.evaluate(B1.data)
 
 
+@T.prim_func
+def global_buffer_with_blockidx(
+    a: T.Buffer[(1, 32), "int32"], b: T.Buffer[(1, 32), "int32"]
+) -> None:
+    for i0 in T.thread_binding(0, 1, thread="blockIdx.x"):
+        for i1 in T.thread_binding(0, 32, thread="threadIdx.x"):
+            with T.block("copy"):
+                i, j = T.axis.remap("SS", [i0, i1])
+                T.reads(a[i, j])
+                T.writes(b[i, j])
+                b[i, j] = a[i, j]
+
+
 def test_buffer_load_store():
     func = buffer_load_store_func
     A, B = [func.buffer_map[x] for x in func.params]
@@ -154,8 +167,21 @@ def test_match_buffer():
     assert lca[B] == block
 
 
+def test_global_buffer_with_blockidx():
+    func = global_buffer_with_blockidx
+    A, B = [func.buffer_map[x] for x in func.params]
+    lca = tir.analysis.detect_buffer_access_lca(func)
+
+    root_block = func.body.block
+    blockidx_loop = root_block.body
+    # LCA of both A and B should be the loop bound to `blockIdx`
+    assert lca[A] == blockidx_loop
+    assert lca[B] == blockidx_loop
+
+
 if __name__ == "__main__":
     test_buffer_load_store()
     test_opaque_access()
     test_lca_func_root()
     test_match_buffer()
+    test_global_buffer_with_blockidx()
```
