
Commit cdc2303
[TIR] Require exactly same-dtype matching for Vulkan smem reuse (#16515)
This PR fixes the StorageRewrite pass, which failed to prevent shared-memory reuse across different dtypes on Vulkan. Since the Vulkan target information is required at lowering time, the `BindTarget` pass must be applied before lowering so that each function carries the correct target attribute. Note that the pass previously checked `Target::Current`, while `tvm.build` does not set the current target. A regression test is added.
1 parent: a3ec544
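As a minimal sketch of the pattern this fix relies on (the variable `mod` is a placeholder for an IRModule containing shared-memory allocations), binding the target to each PrimFunc before lowering lets target-dependent passes read the function's `target` attribute instead of `Target::Current`:

    import tvm

    # `mod` is a placeholder IRModule with shared-memory allocations.
    # BindTarget attaches the target as a function attribute, so passes
    # running inside tvm.lower can query f->GetAttr("target").
    target = tvm.target.Target("vulkan")
    mod = tvm.tir.transform.BindTarget(target)(mod)
    lowered = tvm.lower(mod)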

File tree

3 files changed: +115 −20 lines

src/tir/transforms/merge_shared_memory_allocations.cc

Lines changed: 0 additions & 5 deletions

@@ -662,11 +662,6 @@ namespace transform {
 Pass MergeSharedMemoryAllocations() {
   auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
     bool merge_static_smem = ctx->GetConfig<Bool>("tir.merge_static_smem", Bool(false)).value();
-    // disable this pass for Vulkan
-    auto target = Target::Current(true);
-    if (target.defined() && target->kind->name == "vulkan") {
-      return f;
-    }
     auto* n = f.CopyOnWrite();
     n->body = MergeSharedMemoryAllocations(std::move(n->body), merge_static_smem);
     return f;
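With the early return removed, MergeSharedMemoryAllocations no longer consults `Target::Current` and always merges when invoked; the Vulkan restriction now lives in StorageRewrite (next file) via the per-function target attribute. A rough sketch of driving the pass directly, assuming it is exposed in Python as `tvm.tir.transform.MergeSharedMemoryAllocations` and that `mod` is a placeholder IRModule:

    import tvm

    # With tir.merge_static_smem enabled in the current PassContext,
    # StorageRewrite skips shared-memory reuse and this dedicated pass
    # performs the merge instead.
    with tvm.transform.PassContext(config={"tir.merge_static_smem": True}):
        mod = tvm.tir.transform.MergeSharedMemoryAllocations()(mod)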

src/tir/transforms/storage_rewrite.cc

Lines changed: 31 additions & 15 deletions

@@ -380,13 +380,15 @@ class StoragePlanRewriter : public StmtExprMutator {
   using StmtEntry = LinearAccessPatternFinder::StmtEntry;
   using AllocEntry = LinearAccessPatternFinder::AllocEntry;
 
-  Stmt Rewrite(Stmt stmt, bool detect_inplace, bool enable_reuse = true) {
+  Stmt Rewrite(Stmt stmt, bool detect_inplace, bool enable_reuse,
+               bool reuse_require_exact_matched_dtype) {
     detect_inplace_ = detect_inplace;
     // plan the rewrite
     LinearAccessPatternFinder finder;
     finder(stmt);
     this->LivenessAnalysis(finder.linear_seq_);
-    this->PlanMemory(finder.linear_seq_, finder.alloc_info_, enable_reuse);
+    this->PlanMemory(finder.linear_seq_, finder.alloc_info_, enable_reuse,
+                     reuse_require_exact_matched_dtype);
     all_buffers_accessed_ = finder.all_buffers_accessed_;
     this->PrepareNewAlloc();
     // start rewrite
@@ -817,7 +819,7 @@ class StoragePlanRewriter : public StmtExprMutator {
   // Memory plan algorithm
   void PlanMemory(const std::vector<StmtEntry>& seq,
                   const std::unordered_map<const VarNode*, AllocEntry>& alloc_info,
-                  bool enable_reuse = true) {
+                  bool enable_reuse, bool reuse_require_exact_matched_dtype) {
     std::unordered_set<const VarNode*> inplace_flag;
 
     for (size_t i = 0; i < seq.size(); ++i) {
@@ -864,8 +866,9 @@ class StoragePlanRewriter : public StmtExprMutator {
         }
       }
       if (dst_entry == nullptr) {
-        dst_entry = FindAlloc(alloc, thread_scope_, storage_scope,
-                              entry.num_physical_dimensions, enable_reuse);
+        dst_entry =
+            FindAlloc(alloc, thread_scope_, storage_scope, entry.num_physical_dimensions,
+                      enable_reuse, reuse_require_exact_matched_dtype);
       }
       dst_entry->allocs.emplace_back(alloc);
       alloc_map_[var] = dst_entry;
@@ -919,7 +922,7 @@ class StoragePlanRewriter : public StmtExprMutator {
 
   StorageEntry* FindAlloc(const AllocateNode* op, const Object* attach_scope,
                           const StorageScope& scope, size_t num_physical_dimensions,
-                          bool enable_reuse = true) {
+                          bool enable_reuse, bool reuse_require_exact_matched_dtype) {
     ICHECK(op != nullptr);
     // skip plan for local variable,
     // compiler can do a better job with register allocation.
@@ -958,6 +961,9 @@ class StoragePlanRewriter : public StmtExprMutator {
        if (e->scope != scope) continue;
        // when not divided, no reuse, eg, float4 vs float3
        if (e->bits_offset % op_elem_bits != 0) continue;
+       if (reuse_require_exact_matched_dtype && e->elem_type != op->dtype) {
+         continue;
+       }
        e->const_nbits = std::max(const_nbits, e->const_nbits);
        const_free_map_.erase(it);
        return e;
@@ -969,6 +975,9 @@ class StoragePlanRewriter : public StmtExprMutator {
        if (e->attach_scope_ != attach_scope) continue;
        if (e->scope != scope) continue;
        if (e->elem_type != op->dtype.element_of()) continue;
+       if (reuse_require_exact_matched_dtype && e->elem_type != op->dtype) {
+         continue;
+       }
        e->const_nbits = std::max(const_nbits, e->const_nbits);
        const_free_map_.erase(it);
        return e;
@@ -1704,17 +1713,24 @@ namespace transform {
 
 Pass StorageRewrite() {
   auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+    bool enable_reuse = true;
+    bool reuse_require_exact_matched_dtype = false;
     bool merge_static_smem = ctx->GetConfig<Bool>("tir.merge_static_smem", Bool(false)).value();
-    // disable merge_static_smem for Vulkan
-    auto target = Target::Current(true);
-    if (target.defined() && target->kind->name == "vulkan") {
-      merge_static_smem = false;
-    }
-    // Only enable reuse when we are not merging static shared memory.
-    // Otherwise we will do it in a separate stage
-    bool enable_reuse = merge_static_smem ? false : true;
+    if (merge_static_smem) {
+      // When `merge_static_smem` is true, we will reuse and merge shared
+      // memory in a dedicated pass `MergeSharedMemoryAllocations`.
+      // And so we don't enable reuse in this pass.
+      enable_reuse = false;
+    }
+
+    Optional<Target> target = f->GetAttr<Target>("target");
+    if (target.defined() && target.value()->kind->name == "vulkan") {
+      // Require exactly same-dtype matching in smem reuse for Vulkan
+      reuse_require_exact_matched_dtype = true;
+    }
     auto* n = f.CopyOnWrite();
-    n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true, enable_reuse);
+    n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true, enable_reuse,
+                                            reuse_require_exact_matched_dtype);
     // Parameters may not be rewritten, but internal allocations may.
     // Vectorization of AllocateConst is currently disabled, as it has
     // indexing issues for types that include padding (e.g. int8x3
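Taken together, the rewritten pass_func derives both knobs per function: reuse is disabled when merging is deferred to MergeSharedMemoryAllocations, and exact dtype matching is required when the function's bound target is Vulkan, so a freed float32 shared entry can no longer host a float16 buffer even though 16 divides 32. A short sketch of exercising the Vulkan path, mirroring the regression test below (`mod` is a placeholder IRModule):

    import tvm

    # After BindTarget, each PrimFunc carries the target attribute, so
    # StorageRewrite sees kind->name == "vulkan" and keeps float32 and
    # float16 shared buffers in separate allocations.
    target = tvm.target.Target("vulkan")
    mod = tvm.tir.transform.BindTarget(target)(mod)
    mod = tvm.tir.transform.StorageRewrite()(mod)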

tests/python/tir-transform/test_tir_transform_storage_rewrite.py

Lines changed: 84 additions & 0 deletions

@@ -15,7 +15,9 @@
 # specific language governing permissions and limitations
 # under the License.
 import sys
+
 import pytest
+
 import tvm
 import tvm.testing
 from tvm import te
@@ -928,5 +930,87 @@ def expected(A: T.Buffer(16, "float32"), D: T.Buffer(16, "float32")):
             D[i] = C[i]
 
 
+def test_vulkan_smem_reuse():
+    target = tvm.target.Target(
+        {
+            "keys": ["vulkan", "gpu"],
+            "kind": "vulkan",
+            "max_num_threads": 256,
+            "max_threads_per_block": 256,
+            "supports_float32": T.bool(True),
+            "supports_int32": T.bool(True),
+            "tag": "",
+            "thread_warp_size": 1,
+        }
+    )
+
+    @T.prim_func(private=True)
+    def func(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float16")):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        A_shared = T.allocate([4], "float32", "shared")
+        A_local = T.allocate([4], "float32", "local")
+        B_shared = T.allocate([4], "float16", "shared")
+        A_shared_1 = T.Buffer((4,), data=A_shared, scope="shared")
+        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
+            A_1 = T.Buffer((4,), data=A.data)
+            A_shared_1[threadIdx_x] = A_1[threadIdx_x]
+        A_local_1 = T.Buffer((4,), data=A_local, scope="local")
+        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
+            A_local_1[threadIdx_x] = A_shared_1[threadIdx_x]
+        B_shared_1 = T.Buffer((4,), "float16", data=B_shared, scope="shared")
+        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
+            B_shared_1[threadIdx_x] = T.Cast("float16", A_local_1[threadIdx_x])
+        threadIdx_x = T.launch_thread("threadIdx.x", 4)
+        B_1 = T.Buffer((4,), "float16", data=B.data)
+        B_1[threadIdx_x] = B_shared_1[threadIdx_x]
+
+    @T.prim_func(private=True)
+    def normal_lowering(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float16")):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        A_shared = T.allocate([4], "float32", "shared")
+        A_local = T.allocate([4], "float32", "local")
+        A_shared_1 = T.Buffer((4,), data=A_shared, scope="shared")
+        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
+            A_1 = T.Buffer((4,), data=A.data)
+            A_shared_1[threadIdx_x] = A_1[threadIdx_x]
+        A_local_1 = T.Buffer((4,), data=A_local, scope="local")
+        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
+            A_local_1[threadIdx_x] = A_shared_1[threadIdx_x]
+        A_shared_2 = T.Buffer((4,), "float16", data=A_shared, scope="shared")
+        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
+            A_shared_2[threadIdx_x] = T.Cast("float16", A_local_1[threadIdx_x])
+        threadIdx_x = T.launch_thread("threadIdx.x", 4)
+        B_1 = T.Buffer((4,), "float16", data=B.data)
+        B_1[threadIdx_x] = A_shared_2[threadIdx_x]
+
+    @T.prim_func(private=True)
+    def no_reuse_lowering(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float16")):
+        T.func_attr({"target": target, "tir.noalias": T.bool(True)})
+        A_shared_1 = T.allocate([4], "float32", "shared")
+        A_local_1 = T.allocate([4], "float32", "local")
+        B_shared_1 = T.allocate([4], "float16", "shared")
+        A_shared_1_1 = T.Buffer((4,), data=A_shared_1, scope="shared")
+        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
+            A_1 = T.Buffer((4,), data=A.data)
+            A_shared_1_1[threadIdx_x] = A_1[threadIdx_x]
+        A_local_1_1 = T.Buffer((4,), data=A_local_1, scope="local")
+        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
+            A_local_1_1[threadIdx_x] = A_shared_1_1[threadIdx_x]
+        B_shared_1_1 = T.Buffer((4,), "float16", data=B_shared_1, scope="shared")
+        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
+            B_shared_1_1[threadIdx_x] = T.Cast("float16", A_local_1_1[threadIdx_x])
+        threadIdx_x = T.launch_thread("threadIdx.x", 4)
+        B_1 = T.Buffer((4,), "float16", data=B.data)
+        B_1[threadIdx_x] = B_shared_1_1[threadIdx_x]
+
+    # Reuse shared memory when lowering without target.
+    mod = tvm.IRModule({"main": func})
+    tvm.ir.assert_structural_equal(tvm.lower(mod)["main"], normal_lowering)
+
+    # No shared memory reuse when lowering with target Vulkan.
+    mod = tvm.tir.transform.BindTarget(target)(mod)
+    tvm.ir.assert_structural_equal(tvm.lower(mod)["main"], no_reuse_lowering)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
