diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc index b082581a5148..2fec88d39dbc 100644 --- a/src/tir/analysis/verify_gpu_code.cc +++ b/src/tir/analysis/verify_gpu_code.cc @@ -30,6 +30,7 @@ #include #include +#include "../../runtime/thread_storage_scope.h" #include "../transforms/ir_utils.h" namespace tvm { @@ -61,11 +62,12 @@ class GPUCodeVerifier : public StmtExprVisitor { void VisitStmt_(const AllocateNode* op) final { StmtVisitor::VisitStmt_(op); auto scope = GetPtrStorageScope(op->buffer_var); + runtime::StorageScope storage_scope = runtime::StorageScope::Create(scope); // visit an allocation of a buffer in shared memory, record its size - if (scope == "local") { + if (storage_scope.rank == runtime::StorageRank::kLocal) { size_t size = static_cast<size_t>(op->ConstantAllocationSize()); local_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes(); - } else if (scope == "shared") { + } else if (storage_scope.rank == runtime::StorageRank::kShared) { size_t size = static_cast<size_t>(op->ConstantAllocationSize()); shared_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes(); } diff --git a/tests/python/unittest/test_tir_analysis_verify_gpu_code.py b/tests/python/unittest/test_tir_analysis_verify_gpu_code.py index b7d78aad140d..33e93447a9f1 100644 --- a/tests/python/unittest/test_tir_analysis_verify_gpu_code.py +++ b/tests/python/unittest/test_tir_analysis_verify_gpu_code.py @@ -32,7 +32,7 @@ def _fverify(f, *_): @tvm.testing.requires_gpu def test_shared_memory(): - def check_shared_memory(dtype): + def check_shared_memory(storage_scope, dtype): N = 1024 M = 128 @@ -43,7 +43,7 @@ def check_shared_memory(dtype): B = te.compute((N,), lambda i: A[i], name="B") s = te.create_schedule([B.op]) - AA = s.cache_read(A, "shared", [B]) + AA = s.cache_read(A, storage_scope, [B]) o, i = s[B].split(s[B].op.axis[0], M) s[AA].compute_at(s[B], o) s[B].bind(o, te.thread_axis("blockIdx.x")) @@ -90,8 +90,9 @@ def 
check_shared_memory(dtype): tvm.build(s, [A, B], target) assert valid[0] - check_shared_memory("float32") - check_shared_memory("int8x4") + check_shared_memory("shared", "float32") + check_shared_memory("shared", "int8x4") + check_shared_memory("shared.dyn", "float32") @tvm.testing.requires_gpu