From c236e16052b5a6c5262030b406288a9d0d041a7e Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Thu, 3 Aug 2023 23:30:17 -0400 Subject: [PATCH] [Cherry-Pick][BugFix][TIR] ThreadSync with shared.dyn awareness Cherry-picked from #15478. This PR fixes an issue in the ThreadSync pass. Prior to this PR, the pass was not aware that all buffers of `shared.dyn` scope share the same shared memory space. This sharing is not necessarily reflected in the IR at the time ThreadSync is applied: each buffer of `shared.dyn` scope may still use its own data Var, so ThreadSync is unable to detect the conflicts and insert the sync instructions properly. This PR makes ThreadSync explicitly aware of the `shared.dyn` scope and redirects all the access vars of `shared.dyn` memory to a common var, so that the ThreadSync analysis can find the conflicts and insert the sync instructions. --- src/tir/transforms/thread_storage_sync.cc | 18 +++++++- .../test_tir_transform_thread_sync.py | 42 +++++++++++++++++++ 2 files changed, 59 insertions(+), 1 deletion(-) diff --git a/src/tir/transforms/thread_storage_sync.cc b/src/tir/transforms/thread_storage_sync.cc index c21afe400c56..d92986e51a9c 100644 --- a/src/tir/transforms/thread_storage_sync.cc +++ b/src/tir/transforms/thread_storage_sync.cc @@ -50,11 +50,27 @@ class ThreadSyncPlanner : public StorageAccessVisitor { } // Plan the sync std::vector Summarize(std::vector seq, const ForNode* loop) final { + // Redirect all "shared.dyn" buffer access to the same buffer var + // so that the accesses can be planned together. 
+ Var shared_dyn_buf; + for (StmtEntry& entry : seq) { + for (AccessEntry& access : entry.access) { + if (access.scope.rank == StorageRank::kShared && access.scope.tag == ".dyn" && + access.buffer.defined()) { + if (!shared_dyn_buf.defined()) { + shared_dyn_buf = access.buffer; + } else { + access.buffer = shared_dyn_buf; + } + } + } + } + // Unsynced reads and writes std::vector reads; std::vector writes; // if it is a loop, rotate two times to consider effect of loop. - // simulation based approach to find dependenceies + // simulation based approach to find dependencies for (size_t i = 0; i < seq.size(); ++i) { const StmtEntry& s = seq[i]; // check if sync before statement is needed. diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py b/tests/python/unittest/test_tir_transform_thread_sync.py index 57ea223cf984..571927dffe6e 100644 --- a/tests/python/unittest/test_tir_transform_thread_sync.py +++ b/tests/python/unittest/test_tir_transform_thread_sync.py @@ -119,7 +119,49 @@ def func(p0_arg: T.Buffer((1, 2, 1, 1), "float32"), p1: T.Buffer(2, "float32")) assert "T.tvm_storage_sync" in str(mod) +def test_sync_shared_dyn(): + @T.prim_func(private=True) + def func(A: T.Buffer((4, 4), "float32"), E: T.Buffer((4, 4), "float32")): + blockIdx_x = T.launch_thread("blockIdx.x", 1) + B = T.allocate([24], "float32", "shared.dyn") + C = T.allocate([1], "float32", "local") + D = T.allocate([16], "float32", "shared.dyn") + threadIdx_x = T.launch_thread("threadIdx.x", 16) + B_1 = T.Buffer((24,), data=B, scope="shared.dyn") + A_1 = T.Buffer((16,), data=A.data) + B_1[threadIdx_x // 4 * 6 + threadIdx_x % 4] = A_1[threadIdx_x] + C_1 = T.Buffer((1,), data=C, scope="local") + C_1[0] = B_1[threadIdx_x // 4 * 6 + threadIdx_x % 4] + D_1 = T.Buffer((16,), data=D, scope="shared.dyn") + D_1[threadIdx_x] = C_1[0] + E_1 = T.Buffer((16,), data=E.data) + E_1[threadIdx_x] = D_1[threadIdx_x] + + @T.prim_func(private=True) + def expected(A: T.Buffer((4, 4), "float32"), E: 
T.Buffer((4, 4), "float32")): + blockIdx_x = T.launch_thread("blockIdx.x", 1) + B_1 = T.allocate([24], "float32", "shared.dyn") + C_1 = T.allocate([1], "float32", "local") + D_1 = T.allocate([16], "float32", "shared.dyn") + threadIdx_x = T.launch_thread("threadIdx.x", 16) + B_1_1 = T.Buffer((24,), data=B_1, scope="shared.dyn") + A_1 = T.Buffer((16,), data=A.data) + B_1_1[threadIdx_x // 4 * 6 + threadIdx_x % 4] = A_1[threadIdx_x] + C_1_1 = T.Buffer((1,), data=C_1, scope="local") + C_1_1[0] = B_1_1[threadIdx_x // 4 * 6 + threadIdx_x % 4] + T.tvm_storage_sync("shared.dyn") + D_1_1 = T.Buffer((16,), data=D_1, scope="shared.dyn") + D_1_1[threadIdx_x] = C_1_1[0] + E_1 = T.Buffer((16,), data=E.data) + E_1[threadIdx_x] = D_1_1[threadIdx_x] + + mod = tvm.IRModule({"main": func}) + mod = tvm.tir.transform.ThreadSync("shared.dyn")(mod) + tvm.ir.assert_structural_equal(mod["main"], expected) + + if __name__ == "__main__": test_thread_storage_sync() test_sync_else_branch() test_sync_read_thread_id_independent_location() + test_sync_shared_dyn()