diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index f5fdd191af87..45e44f96a633 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -614,9 +614,9 @@ TVM_DLL Pass InstallDebugSpans();
 TVM_DLL Pass UnifyThreadBinding();
 
 /*!
- * A pass to merge multiple TIR-level dynamic shared memory allocations into one
+ * A pass to merge multiple TIR-level shared memory allocations into one
  */
-TVM_DLL Pass MergeDynamicSharedMemoryAllocations();
+TVM_DLL Pass MergeSharedMemoryAllocations();
 
 /*!
  * \brief This pass is post-scheduling pass to convert all
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index c58062045c64..ba6b5974e27f 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -1000,8 +1000,8 @@ def UnifyThreadBinding():
     return _ffi_api.UnifyThreadBinding()  # type: ignore
 
 
-def MergeDynamicSharedMemoryAllocations():
-    """This pass merges multiple TIR-level dynamic shared memory allocations
+def MergeSharedMemoryAllocations():
+    """This pass merges multiple TIR-level shared memory allocations
     into one allocation.
 
     Returns
@@ -1009,7 +1009,7 @@
     fpass : tvm.transform.Pass
         The result pass
     """
-    return _ffi_api.MergeDynamicSharedMemoryAllocations()  # type: ignore
+    return _ffi_api.MergeSharedMemoryAllocations()  # type: ignore
 
 
 def ConvertForLoopsToSerial():
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index b7ba0ffe4468..17cd5c49a1bf 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -52,6 +52,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array<Array<ObjectRef>>);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.merge_static_smem", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_lwp", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.vtcm_capacity", Integer);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.ptx_ldg32", Bool);
@@ -584,7 +585,7 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
 
   mixed_pass_list.push_back(tir::transform::ThreadSync("shared"));
   mixed_pass_list.push_back(tir::transform::ThreadSync("shared.dyn"));
-  mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations());
+  mixed_pass_list.push_back(tir::transform::MergeSharedMemoryAllocations());
   mixed_pass_list.push_back(tir::transform::ThreadSync("warp"));
   mixed_pass_list.push_back(tir::transform::InferFragment());
   mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce());
diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc
index 2fb97d32eb74..17283063543d 100644
--- a/src/meta_schedule/postproc/verify_gpu_code.cc
+++ b/src/meta_schedule/postproc/verify_gpu_code.cc
@@ -179,7 +179,7 @@ class VerifyGPUCodeNode : public PostprocNode {
     pass_list.push_back(tir::transform::InjectVirtualThread());
     pass_list.push_back(tir::transform::InjectDoubleBuffer());
     pass_list.push_back(tir::transform::StorageRewrite());
-    pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations());
+    pass_list.push_back(tir::transform::MergeSharedMemoryAllocations());
     pass_list.push_back(tir::transform::LowerIntrin());
     // Convert Function to IRModule
     transform::PassContext pass_ctx = transform::PassContext::Current();
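The `tir.merge_static_smem` option registered in `driver_api.cc` above defaults to `false`. A minimal sketch of how a user opts in; the `PassContext` usage mirrors the new test added at the end of this patch, and the build call is illustrative:

```python
import tvm

# Enable merging of static shared memory allocations for everything lowered
# inside this context. "tir.merge_static_smem" is the config key registered above.
with tvm.transform.PassContext(config={"tir.merge_static_smem": True}):
    ...  # e.g. lib = tvm.build(sch, args, target="cuda")
```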
diff --git a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc b/src/tir/transforms/merge_shared_memory_allocations.cc
similarity index 84%
rename from src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
rename to src/tir/transforms/merge_shared_memory_allocations.cc
index 99055cebf2dc..75396407352c 100644
--- a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
+++ b/src/tir/transforms/merge_shared_memory_allocations.cc
@@ -18,9 +18,10 @@
  */
 
 /*!
- * \file merge_dynamic_shared_memory_allocations.cc
- * \brief Each GPU kernel is allowed to have only one dynamic shared memory allocation.
- * This pass merges multiple TIR-level dynamic shared memory allocations into one allocation.
+ * \file merge_shared_memory_allocations.cc
+ * \brief Each GPU kernel is allowed to have only one dynamic or static shared memory allocation.
+ * This pass merges multiple TIR-level dynamic or static shared memory allocations into one
+ * allocation.
  */
 #include <tvm/runtime/registry.h>
 #include <tvm/tir/expr.h>
@@ -45,6 +46,11 @@ bool IsDynamicSharedMemory(Var buffer_var) {
   return storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn";
 }
 
+bool IsStaticSharedMemory(Var buffer_var) {
+  StorageScope storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
+  return storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == "";
+}
+
 /*!
  * \brief collect the mapping from the buffer var to its allocate
  */
@@ -53,11 +59,15 @@ class AllocateCollector : public StmtExprVisitor {
   void VisitStmt_(const AllocateNode* op) final {
     if (IsDynamicSharedMemory(op->buffer_var)) {
       dyn_shmem_allocs_[op->buffer_var.get()] = op;
+    } else if (IsStaticSharedMemory(op->buffer_var)) {
+      static_shmem_allocs_[op->buffer_var.get()] = op;
     }
     StmtExprVisitor::VisitStmt_(op);
   }
-  // The mapping from the original buffer var to its allocate
+  // The dynamic mapping from the original buffer var to its allocate
   std::unordered_map<const VarNode*, const AllocateNode*> dyn_shmem_allocs_;
+  // The static mapping from the original buffer var to its allocate
+  std::unordered_map<const VarNode*, const AllocateNode*> static_shmem_allocs_;
 };
 
 // Find a linear pattern of storage access
@@ -73,8 +83,9 @@
 // The storage need to be kept alive between Allocate and last access.
 // The free point is only inserted at the same scope of Allocate.
 //
-class DynSharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
+class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
  public:
+  explicit SharedMemLinearAccessPatternFinder(bool is_dynamic = true) : is_dynamic_(is_dynamic) {}
   /*! \brief record the touch list of statement. */
   struct StmtEntry {
     // The statement
@@ -112,7 +123,7 @@
     auto it = alloc_info_.find(buf);
     if (it != alloc_info_.end() && it->second.alloc) {
       ICHECK_LT(it->second.level, scope_.size());
-      if (IsDynamicSharedMemory(GetRef<Var>(buf))) {
+      if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
         scope_[it->second.level].touched.push_back(buf);
       }
     }
@@ -143,7 +154,7 @@
     auto it = alloc_info_.find(buf);
     if (it != alloc_info_.end() && it->second.alloc) {
       ICHECK_LT(it->second.level, scope_.size()) << "Load memory in places other than store.";
-      if (IsDynamicSharedMemory(GetRef<Var>(buf))) {
+      if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
         scope_[it->second.level].touched.push_back(buf);
       }
     }
@@ -164,7 +175,7 @@
     auto it = alloc_info_.find(buf);
     if (it != alloc_info_.end() && it->second.alloc) {
       ICHECK_LT(it->second.level, scope_.size());
-      if (IsDynamicSharedMemory(GetRef<Var>(buf))) {
+      if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
         scope_[it->second.level].touched.push_back(buf);
       }
     }
@@ -217,6 +228,12 @@
   std::unordered_map<const VarNode*, AllocEntry> alloc_info_;
 
  private:
+  // Wrapper function to determine if the shared memory allocation for a variable is appropriate.
+  bool IsAppropriateSharedMemory(const Var& var) {
+    return is_dynamic_ ? IsDynamicSharedMemory(var) : IsStaticSharedMemory(var);
+  }
+  // Whether to do dynamic analysis.
+  bool is_dynamic_{true};
   // Whether already in thread env.
   bool in_thread_env_{false};
   // The scope stack.
@@ -226,18 +243,23 @@
 /*!
  * \brief merge the buffers whose live range has no intersection and rewrite the body
  */
-class DynamicSharedMemoryRewriter : public StmtExprMutator {
+class SharedMemoryRewriter : public StmtExprMutator {
  public:
-  explicit DynamicSharedMemoryRewriter(
-      const std::unordered_map<const VarNode*, const AllocateNode*>& dyn_shmem_allocs)
-      : dyn_shmem_allocs_{dyn_shmem_allocs} {}
+  explicit SharedMemoryRewriter(
+      const std::unordered_map<const VarNode*, const AllocateNode*>& shmem_allocs,
+      bool is_dynamic = true)
+      : is_dynamic_{is_dynamic}, shmem_allocs_{shmem_allocs} {
+    if (!is_dynamic) {
+      merged_buf_var_ = Var("buf_shmem", PointerType(PrimType(DataType::UInt(8)), "shared"));
+    }
+  }
 
   /*!
    * \brief plan the memory reuse for all the buffer allocated in the statement
    * \param stmt the statement
    */
-  void PlanReuse(const Stmt& stmt) {
-    DynSharedMemLinearAccessPatternFinder finder;
+  void PlanReuse(const Stmt& stmt, bool is_dynamic = true) {
+    SharedMemLinearAccessPatternFinder finder(is_dynamic);
     finder(stmt);
     this->LivenessAnalysis(finder.linear_seq_);
     this->PlanMemory(finder.linear_seq_);
@@ -263,7 +285,7 @@
     for (const StorageEntry* e : all_entry) {
       for (int i = 0; i < static_cast<int>(e->allocs.size()); i++) {
         for (const VarNode* buffer : e->allocs[i]) {
-          const AllocateNode* alloc = dyn_shmem_allocs_[buffer];
+          const AllocateNode* alloc = shmem_allocs_[buffer];
           align[i] = std::max(align[i], alloc->dtype.bytes());
         }
       }
@@ -274,7 +296,7 @@
       for (int i = 0; i < static_cast<int>(e->allocs.size()); i++) {
         PrimExpr inner_offset = 0;
         for (const VarNode* buffer : e->allocs[i]) {
-          const AllocateNode* alloc = dyn_shmem_allocs_[buffer];
+          const AllocateNode* alloc = shmem_allocs_[buffer];
           buffer_byte_offsets_[buffer] = merged_alloc_size_ + inner_offset;
           inner_offset += alloc->extents[0] * alloc->dtype.bytes();
           inner_offset += indexmod(align[i] - indexmod(inner_offset, align[i]), align[i]);
@@ -293,7 +315,7 @@
   }
 
   Stmt VisitStmt_(const AllocateNode* op) final {
-    if (IsDynamicSharedMemory(op->buffer_var)) {
+    if (IsAppropriateSharedMemory(op->buffer_var)) {
       return StmtExprMutator::VisitStmt(op->body);
     }
     return StmtExprMutator::VisitStmt_(op);
@@ -319,9 +341,9 @@
 
   template <typename Node>
   Node VisitBufferAccess(Node node) {
-    if (IsDynamicSharedMemory(node->buffer->data)) {
+    if (IsAppropriateSharedMemory(node->buffer->data)) {
       ICHECK_EQ(node->indices.size(), 1)
-          << "MergeDynamicSharedMemoryAllocations expects flat memory buffers, "
+          << "MergeSharedMemoryAllocations expects flat memory buffers, "
           << "and is to be run after "
          << "StorageFlatten (TE schedules) or FlattenBuffer (TIR schedules)";
       Array<PrimExpr> indices = {node->indices[0] +
@@ -342,10 +364,10 @@
       return it->second;
     }
 
-    if (IsDynamicSharedMemory(buffer->data)) {
+    if (IsAppropriateSharedMemory(buffer->data)) {
       ICHECK_EQ(buffer->shape.size(), 1)
           << "Buffer " << buffer << " has shape " << buffer->shape << ". "
-          << "MergeDynamicSharedMemoryAllocations expects flat memory buffers, "
+          << "MergeSharedMemoryAllocations expects flat memory buffers, "
          << "and is to be run after "
          << "StorageFlatten (TE schedules) or FlattenBuffer (TIR schedules)";
       auto writer = buffer.CopyOnWrite();
@@ -361,7 +383,7 @@
     ICHECK_EQ(op->args.size(), 5U);
     DataType dtype = op->args[0].dtype();
     Var buffer = Downcast<Var>(op->args[1]);
-    if (!IsDynamicSharedMemory(buffer)) {
+    if (!IsAppropriateSharedMemory(buffer)) {
       return StmtExprMutator::VisitExpr_(op);
     }
     PrimExpr extra_offset = GetBufferOffset(buffer, dtype);
@@ -381,7 +403,12 @@
     return indexdiv(it->second, dtype.bytes());
   }
 
-  using StmtEntry = DynSharedMemLinearAccessPatternFinder::StmtEntry;
+  // Wrapper function to determine if the shared memory allocation for a variable is appropriate.
+  bool IsAppropriateSharedMemory(const Var& var) {
+    return is_dynamic_ ? IsDynamicSharedMemory(var) : IsStaticSharedMemory(var);
+  }
+
+  using StmtEntry = SharedMemLinearAccessPatternFinder::StmtEntry;
   struct StorageEntry {
     // The constant size of the buffer in bits, only used if it is constant
     uint64_t const_nbits{0};
@@ -458,8 +485,8 @@
     // In both cases, we need to handle the gen event correctly
     if (it != event_map_.end() && seq[i].scope_pair_offset >= 0) {
       for (const VarNode* var : it->second.gen) {
-        ICHECK(dyn_shmem_allocs_.count(var));
-        const AllocateNode* alloc = dyn_shmem_allocs_[var];
+        ICHECK(shmem_allocs_.count(var));
+        const AllocateNode* alloc = shmem_allocs_[var];
         StorageEntry* dst_entry = FindAlloc(alloc);
         alloc_map_[var] = dst_entry;
       }
@@ -578,10 +605,12 @@
       sym_free_list_.push_back(e);
     }
   }
+  // Whether to enable dynamic analysis.
+  bool is_dynamic_{true};
   // The var for the merged buffer
   Var merged_buf_var_{"buf_dyn_shmem", PointerType(PrimType(DataType::UInt(8)), "shared.dyn")};
   // The mapping from the original buffer var to its allocate
-  std::unordered_map<const VarNode*, const AllocateNode*> dyn_shmem_allocs_;
+  std::unordered_map<const VarNode*, const AllocateNode*> shmem_allocs_;
   // The size of the merged buffer
   PrimExpr merged_alloc_size_{0};
   // The mapping from the original buffer var to its offset in the merged buffer
@@ -602,30 +631,36 @@
   support::Arena arena_;
 };
 
-Stmt MergeDynamicSharedMemoryAllocations(Stmt stmt) {
+Stmt MergeSharedMemoryAllocations(Stmt stmt, bool merge_static_smem) {
   AllocateCollector collector;
   collector(stmt);
   if (collector.dyn_shmem_allocs_.size() > 1) {
-    DynamicSharedMemoryRewriter rewriter(collector.dyn_shmem_allocs_);
+    SharedMemoryRewriter rewriter(collector.dyn_shmem_allocs_);
     rewriter.PlanReuse(stmt);
-    return rewriter(std::move(stmt));
+    stmt = rewriter(std::move(stmt));
+  }
+  if (merge_static_smem && collector.static_shmem_allocs_.size() > 1) {
+    SharedMemoryRewriter rewriter(collector.static_shmem_allocs_, false);
+    rewriter.PlanReuse(stmt, false);
+    stmt = rewriter(std::move(stmt));
   }
   return stmt;
 }
 
 namespace transform {
 
-Pass MergeDynamicSharedMemoryAllocations() {
+Pass MergeSharedMemoryAllocations() {
   auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+    bool merge_static_smem = ctx->GetConfig<Bool>("tir.merge_static_smem", Bool(false)).value();
     auto* n = f.CopyOnWrite();
-    n->body = MergeDynamicSharedMemoryAllocations(std::move(n->body));
+    n->body = MergeSharedMemoryAllocations(std::move(n->body), merge_static_smem);
     return f;
   };
-  return CreatePrimFuncPass(pass_func, 0, "tir.MergeDynamicSharedMemoryAllocations", {});
+  return CreatePrimFuncPass(pass_func, 0, "tir.MergeSharedMemoryAllocations", {});
 }
 
-TVM_REGISTER_GLOBAL("tir.transform.MergeDynamicSharedMemoryAllocations")
-    .set_body_typed(MergeDynamicSharedMemoryAllocations);
+TVM_REGISTER_GLOBAL("tir.transform.MergeSharedMemoryAllocations")
+    .set_body_typed(MergeSharedMemoryAllocations);
 
 }  // namespace transform
 }  // namespace tir
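In the offset assignment above, each buffer is placed at the running byte offset and the offset is then rounded up to the slot's alignment; `indexmod(align[i] - indexmod(inner_offset, align[i]), align[i])` is exactly the padding needed to reach the next multiple of `align[i]`. A standalone sketch of that arithmetic, simplified to a single alignment and without the liveness-based reuse the pass also performs (names and extents are illustrative):

```python
# Sketch of the byte-offset packing used by SharedMemoryRewriter.
# allocs: list of (extent_in_elements, dtype_bytes) pairs.
def merged_offsets(allocs):
    align = max(nbytes for _, nbytes in allocs)
    offsets, off = [], 0
    for extent, nbytes in allocs:
        offsets.append(off)
        off += extent * nbytes
        off += (align - off % align) % align  # pad to the next aligned boundary
    return offsets, off

# Two 16x16 fp16 tiles and one 16x16 fp32 tile packed back to back:
print(merged_offsets([(256, 2), (256, 2), (256, 4)]))  # -> ([0, 512, 1024], 2048)
```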
diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc
index f271769c804b..70f325e4a21e 100644
--- a/src/tir/transforms/storage_rewrite.cc
+++ b/src/tir/transforms/storage_rewrite.cc
@@ -380,13 +380,13 @@ class StoragePlanRewriter : public StmtExprMutator {
   using StmtEntry = LinearAccessPatternFinder::StmtEntry;
   using AllocEntry = LinearAccessPatternFinder::AllocEntry;
 
-  Stmt Rewrite(Stmt stmt, bool detect_inplace) {
+  Stmt Rewrite(Stmt stmt, bool detect_inplace, bool enable_reuse = true) {
     detect_inplace_ = detect_inplace;
     // plan the rewrite
     LinearAccessPatternFinder finder;
     finder(stmt);
     this->LivenessAnalysis(finder.linear_seq_);
-    this->PlanMemory(finder.linear_seq_, finder.alloc_info_);
+    this->PlanMemory(finder.linear_seq_, finder.alloc_info_, enable_reuse);
     all_buffers_accessed_ = finder.all_buffers_accessed_;
     this->PrepareNewAlloc();
     // start rewrite
@@ -816,7 +816,8 @@
 
   // Memory plan algorithm
   void PlanMemory(const std::vector<StmtEntry>& seq,
-                  const std::unordered_map<const VarNode*, AllocEntry>& alloc_info) {
+                  const std::unordered_map<const VarNode*, AllocEntry>& alloc_info,
+                  bool enable_reuse = true) {
     std::unordered_set<const VarNode*> inplace_flag;
 
     for (size_t i = 0; i < seq.size(); ++i) {
@@ -863,8 +864,8 @@
         }
       }
       if (dst_entry == nullptr) {
-        dst_entry =
-            FindAlloc(alloc, thread_scope_, storage_scope, entry.num_physical_dimensions);
+        dst_entry = FindAlloc(alloc, thread_scope_, storage_scope,
+                              entry.num_physical_dimensions, enable_reuse);
       }
       dst_entry->allocs.emplace_back(alloc);
       alloc_map_[var] = dst_entry;
@@ -917,7 +918,8 @@
   }
 
   StorageEntry* FindAlloc(const AllocateNode* op, const Object* attach_scope,
-                          const StorageScope& scope, size_t num_physical_dimensions) {
+                          const StorageScope& scope, size_t num_physical_dimensions,
+                          bool enable_reuse = true) {
     ICHECK(op != nullptr);
     // skip plan for local variable,
     // compiler can do a better job with register allocation.
@@ -940,7 +942,7 @@
         (scope.tag.length() == 0) && (scope.rank >= StorageRank::kWarp || op->dtype.is_handle() ||
                                       (is_known_size && const_nbits <= 32));
 
-    if (is_small_array || !is_flat_memory_space) {
+    if (!enable_reuse || is_small_array || !is_flat_memory_space) {
       return NewAlloc(op, attach_scope, scope, const_nbits);
     }
 
@@ -1702,8 +1704,9 @@ namespace transform {
 
 Pass StorageRewrite() {
   auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+    bool merge_static_smem = ctx->GetConfig<Bool>("tir.merge_static_smem", Bool(false)).value();
     auto* n = f.CopyOnWrite();
-    n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true);
+    n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true, !merge_static_smem);
     // Parameters may not be rewritten, but internal allocations may.
     // Vectorization of AllocateConst is currently disabled, as it has
     // indexing issues for types that include padding (e.g. int8x3
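When `tir.merge_static_smem` is true, `StorageRewrite` runs with `enable_reuse = false`, so `FindAlloc` always returns a fresh `StorageEntry` and the static shared buffers survive intact to `MergeSharedMemoryAllocations`, which then does the liveness-based packing itself. A hypothetical before/after of the net effect on the lowered TIR (extents illustrative; `buf_shmem` is the var created in the rewriter's constructor above):

```python
# Before: separate static shared allocations, each a distinct Allocate node.
#     A_sh = T.allocate([256], "float16", "shared")
#     B_sh = T.allocate([256], "float16", "shared")
# After MergeSharedMemoryAllocations (static mode): a single uint8 buffer,
# with all accesses rewritten to the byte offsets in buffer_byte_offsets_.
#     buf_shmem = T.allocate([1024], "uint8", "shared")
```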
diff --git a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
index bf68200944d4..c52aca767410 100644
--- a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
+++ b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
@@ -161,7 +161,7 @@ def test_inject_async_copy_shared_dyn():
     mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
     mod = tvm.tir.transform.FlattenBuffer()(mod)
     mod = tvm.tir.transform.VectorizeLoop()(mod)
-    mod = tvm.tir.transform.MergeDynamicSharedMemoryAllocations()(mod)
+    mod = tvm.tir.transform.MergeSharedMemoryAllocations()(mod)
     mod = tvm.tir.transform.InjectPTXAsyncCopy()(mod)
 
     assert count_cp_async(mod["main"].body) == 2
diff --git a/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py b/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py
index 37372059a296..1eb9cd97cfd7 100644
--- a/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py
+++ b/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py
@@ -32,7 +32,7 @@ def run_passes(sch, args):
             tvm.tir.transform.Simplify(),
             tvm.tir.transform.VectorizeLoop(),
             tvm.tir.transform.StorageRewrite(),
-            tvm.tir.transform.MergeDynamicSharedMemoryAllocations(),
+            tvm.tir.transform.MergeSharedMemoryAllocations(),
         ]
     )(mod)
 
@@ -336,7 +336,7 @@ class TestMatmul(tvm.testing.CompareBeforeAfter):
     for the replaced allocations.
     """
 
-    transform = tvm.tir.transform.MergeDynamicSharedMemoryAllocations()
+    transform = tvm.tir.transform.MergeSharedMemoryAllocations()
 
     use_decl_buffer = tvm.testing.parameter(by_dict={"t_buffer": False, "decl_buffer": True})
diff --git a/tests/python/tir-transform/test_tir_transform_merge_static_shared_memory_allocations.py b/tests/python/tir-transform/test_tir_transform_merge_static_shared_memory_allocations.py
new file mode 100644
index 000000000000..be32514a720c
--- /dev/null
+++ b/tests/python/tir-transform/test_tir_transform_merge_static_shared_memory_allocations.py
@@ -0,0 +1,203 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+
+import tvm
+import tvm.testing
+from tvm import te
+from tvm.driver.build_module import schedule_to_module
+from tvm.topi.math import cast
+from tvm.script import tir as T
+
+
+def run_passes(sch, args):
+    mod = schedule_to_module(sch, args)
+    with tvm.transform.PassContext(config={"tir.merge_static_smem": True}):
+        return tvm.transform.Sequential(
+            [
+                tvm.tir.transform.StorageFlatten(64),
+                tvm.tir.transform.Simplify(),
+                tvm.tir.transform.VectorizeLoop(),
+                tvm.tir.transform.StorageRewrite(),
+                tvm.tir.transform.MergeSharedMemoryAllocations(),
+            ]
+        )(mod)
+
+
+def verify_single_allocation(stmt, alloc_size=None):
+    num_alloc = [0]
+    alloc_extents = []
+
+    def verify(n):
+        if (
+            isinstance(n, tvm.tir.Allocate)
+            and n.buffer_var.type_annotation.storage_scope == "shared"
+        ):
+            num_alloc[0] += 1
+            alloc_extents.append(n.extents[0])
+
+    tvm.tir.stmt_functor.post_order_visit(stmt, verify)
+    assert num_alloc[0] == 1
+
+    if alloc_size:
+        assert alloc_extents[0] == alloc_size
+
+
+@tvm.testing.requires_gpu
+def test_matmul_shared():
+    n = 1024
+    block = 16
+    A = te.placeholder((n, n), name="A", dtype="float16")
+    B = te.placeholder((n, n), name="B", dtype="float16")
+
+    def syncthread():
+        return tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"]))
+
+    def test_matmul_ir(A, B, C):
+        ib = tvm.tir.ir_builder.create()
+
+        tx = te.thread_axis("threadIdx.x")
+        ty = te.thread_axis("threadIdx.y")
+        bx = te.thread_axis("blockIdx.x")
+        by = te.thread_axis("blockIdx.y")
+        ib.scope_attr(tx, "thread_extent", block)
+        ib.scope_attr(ty, "thread_extent", block)
+        ib.scope_attr(bx, "thread_extent", n // block)
+        ib.scope_attr(by, "thread_extent", n // block)
+
+        A_sh = ib.allocate(A.dtype, (block, block), scope="shared", name="A_sh")  # fp16
+        B_sh = ib.allocate(B.dtype, (block, block), scope="shared", name="B_sh")  # fp16
+        # Create a shared memory for the accumulation.
+        # This is for testing merging shared memory allocations with different data types.
+        # In practice, there is no need to allocate a shared memory for C.
+        C_local = ib.allocate(C.dtype, (1,), scope="local", name="C_local")
+        C_sh = ib.allocate(C.dtype, (block, block), scope="shared", name="C_sh")  # fp32
+
+        A_ptr = ib.buffer_ptr(A)
+        B_ptr = ib.buffer_ptr(B)
+        C_ptr = ib.buffer_ptr(C)
+
+        C_local[0] = 0.0
+
+        with ib.for_range(0, n // block, name="i") as i:
+            A_sh[ty, tx] = A_ptr[by * block + ty, i * block + tx]
+            B_sh[ty, tx] = B_ptr[i * block + ty, bx * block + tx]
+            ib.emit(syncthread())
+
+            with ib.for_range(0, block, name="k") as k:
+                C_local[0] += cast(A_sh[ty, k] * B_sh[k, tx], "float32")
+            ib.emit(syncthread())
+
+        C_sh[ty, tx] = C_local[0]
+        C_ptr[by * block + ty, bx * block + tx] = C_sh[ty, tx]
+
+        return ib.get()
+
+    C = te.extern(
+        A.shape,
+        [A, B],
+        lambda ins, outs: test_matmul_ir(ins[0], ins[1], outs[0]),
+        name="matmul",
+        dtype="float32",
+    )
+    s = te.create_schedule(C.op)
+    mod = run_passes(s, [A, B, C])
+    # C_sh is not live at the same time as A_sh/B_sh, so it can be allocated at A_sh's
+    # offset; the merged buffer only needs 2 * block * block float16 elements.
+    expected_alloc_size = block * block * 4
+    verify_single_allocation(mod["main"].body, expected_alloc_size)
+
+    def check_target(target):
+        if not tvm.testing.device_enabled(target):
+            return
+
+        fmatmul = tvm.build(s, [A, B, C], target)
+        dev = tvm.device(target, 0)
+
+        size = (n, n)
+        a_np = np.random.uniform(size=size).astype(A.dtype)
+        b_np = np.random.uniform(size=size).astype(B.dtype)
+        a = tvm.nd.array(a_np, dev)
+        b = tvm.nd.array(b_np, dev)
+        c = tvm.nd.array(np.zeros(size, dtype=C.dtype), dev)
+        fmatmul(a, b, c)
+        np_ref = np.dot(a_np.astype("float32"), b_np.astype("float32"))
+        tvm.testing.assert_allclose(c.numpy(), np_ref, 1e-4, 1e-4)
+
+    for target in ["cuda"]:
+        check_target(target)
+
+
+@tvm.testing.requires_gpu
+def test_shared_more_dtype():
+    """Test merging static shared memory allocations with different data types"""
+    n = 512
+    A = te.placeholder((n,), name="A", dtype="int8")
+    B = te.placeholder((n,), name="B", dtype="int16")
+
+    def test_device_ir(A, B, C):
+        n = A.shape[0]
+        ib = tvm.tir.ir_builder.create()
+
+        tx = te.thread_axis("threadIdx.x")
+        ib.scope_attr(tx, "thread_extent", n)
+
+        A_sh = ib.allocate(A.dtype, (n,), scope="shared")  # i8
+        B_sh = ib.allocate(B.dtype, (n,), scope="shared")  # i16
+        C_sh = ib.allocate(C.dtype, (n,), scope="shared")  # i32
+
+        Aptr = ib.buffer_ptr(A)
+        Bptr = ib.buffer_ptr(B)
+        Cptr = ib.buffer_ptr(C)
+
+        A_sh[tx] = Aptr[tx]
+        B_sh[tx] = Bptr[tx]
+
+        C_sh[tx] = cast(A_sh[tx], "int32") + cast(B_sh[tx], "int32")
+        Cptr[tx] = C_sh[tx]
+        return ib.get()
+
+    C = te.extern(
+        (n,),
+        [A, B],
+        lambda ins, outs: test_device_ir(ins[0], ins[1], outs[0]),
+        name="vadd",
+        dtype="int32",
+    )
+    s = te.create_schedule(C.op)
+
+    mod = run_passes(s, [A, B, C])
+    verify_single_allocation(mod["main"].body, n * 4)
+
+    def check_target(target):
+        if not tvm.testing.device_enabled(target):
+            return
+
+        fadd = tvm.build(s, [A, B, C], target)
+        dev = tvm.device(target, 0)
+
+        a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev)
+        b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), dev)
+        c = tvm.nd.array(np.zeros((n,), dtype=C.dtype), dev)
+        fadd(a, b, c)
+        tvm.testing.assert_allclose(c.numpy(), a.numpy().astype("float32") + b.numpy(), 1e-4, 1e-4)
+
+    for target in ["cuda"]:
+        check_target(target)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
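An end-to-end sketch of what these tests exercise (hypothetical usage, not part of the patch; `s`, `A`, `B`, `C` as in `test_matmul_shared`): building under the flag should leave a single merged shared buffer in the generated CUDA.

```python
# Build with static shared memory merging enabled and inspect the kernel.
with tvm.transform.PassContext(config={"tir.merge_static_smem": True}):
    fmatmul = tvm.build(s, [A, B, C], "cuda")
# Expect one merged "buf_shmem" __shared__ array instead of A_sh/B_sh/C_sh.
print(fmatmul.imported_modules[0].get_source())
```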