Skip to content

Commit 534205b

Browse files
authored
[TIR] Check dynamic shared memory in VerifyGPUCode (apache#10923)
1 parent 27c5910 commit 534205b

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

src/tir/analysis/verify_gpu_code.cc

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,7 @@
3030
#include <tvm/tir/stmt.h>
3131
#include <tvm/tir/stmt_functor.h>
3232

33+
#include "../../runtime/thread_storage_scope.h"
3334
#include "../transforms/ir_utils.h"
3435

3536
namespace tvm {
@@ -61,11 +62,12 @@ class GPUCodeVerifier : public StmtExprVisitor {
6162
void VisitStmt_(const AllocateNode* op) final {
6263
StmtVisitor::VisitStmt_(op);
6364
auto scope = GetPtrStorageScope(op->buffer_var);
65+
runtime::StorageScope storage_scope = runtime::StorageScope::Create(scope);
6466
// visit an allocation of a buffer in shared memory, record its size
65-
if (scope == "local") {
67+
if (storage_scope.rank == runtime::StorageRank::kLocal) {
6668
size_t size = static_cast<size_t>(op->ConstantAllocationSize());
6769
local_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes();
68-
} else if (scope == "shared") {
70+
} else if (storage_scope.rank == runtime::StorageRank::kShared) {
6971
size_t size = static_cast<size_t>(op->ConstantAllocationSize());
7072
shared_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes();
7173
}

tests/python/unittest/test_tir_analysis_verify_gpu_code.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,7 @@ def _fverify(f, *_):
3232

3333
@tvm.testing.requires_gpu
3434
def test_shared_memory():
35-
def check_shared_memory(dtype):
35+
def check_shared_memory(storage_scope, dtype):
3636
N = 1024
3737
M = 128
3838

@@ -43,7 +43,7 @@ def check_shared_memory(dtype):
4343
B = te.compute((N,), lambda i: A[i], name="B")
4444

4545
s = te.create_schedule([B.op])
46-
AA = s.cache_read(A, "shared", [B])
46+
AA = s.cache_read(A, storage_scope, [B])
4747
o, i = s[B].split(s[B].op.axis[0], M)
4848
s[AA].compute_at(s[B], o)
4949
s[B].bind(o, te.thread_axis("blockIdx.x"))
@@ -90,8 +90,9 @@ def check_shared_memory(dtype):
9090
tvm.build(s, [A, B], target)
9191
assert valid[0]
9292

93-
check_shared_memory("float32")
94-
check_shared_memory("int8x4")
93+
check_shared_memory("shared", "float32")
94+
check_shared_memory("shared", "int8x4")
95+
check_shared_memory("shared.dyn", "float32")
9596

9697

9798
@tvm.testing.requires_gpu

0 commit comments

Comments
 (0)