From 2f502b43b47ef0303a1e88d6f2f43101678124ba Mon Sep 17 00:00:00 2001 From: weitao <51255903105@stu.ecnu.edu.cn> Date: Wed, 27 Mar 2024 07:00:31 +0000 Subject: [PATCH 1/5] [BugTIR]fix error merging shared memory for ptx_cp_async --- .../merge_shared_memory_allocations.cc | 21 ++++++++++++++ ...merge_dynamic_shared_memory_allocations.py | 29 +++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/src/tir/transforms/merge_shared_memory_allocations.cc b/src/tir/transforms/merge_shared_memory_allocations.cc index c79b9c1f9399..9eb3b5e62392 100644 --- a/src/tir/transforms/merge_shared_memory_allocations.cc +++ b/src/tir/transforms/merge_shared_memory_allocations.cc @@ -170,6 +170,7 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor { StmtExprVisitor::VisitExpr_(op); } } + void VisitExpr_(const VarNode* buf) final { // Directly reference to the variable count as a read. auto it = alloc_info_.find(buf); @@ -180,6 +181,7 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor { } } } + template void VisitNewScope(const T* op) { scope_.push_back(StmtEntry()); @@ -200,6 +202,7 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor { ICHECK_NE(end_index, 0U); linear_seq_[begin_index].scope_pair_offset = end_index - begin_index; } + void VisitStmt_(const AttrStmtNode* op) final { // Only record the outer most thread extent. if (op->attr_key == attr::thread_extent && !in_thread_env_) { @@ -214,6 +217,7 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor { StmtExprVisitor::VisitStmt_(op); } } + void VisitStmt_(const IfThenElseNode* op) final { VisitNewScope(op); } void VisitStmt_(const ForNode* op) final { VisitNewScope(op); } @@ -392,6 +396,23 @@ class SharedMemoryRewriter : public StmtExprMutator { PrimExpr extent = this->VisitExpr(op->args[3]); return Call(op->dtype, op->op, {op->args[0], merged_buf_var_, extra_offset + offset, extent, op->args[4]}); + } else if (op->op.same_as(builtin::ptx_cp_async())) { + ICHECK((op->args.size() == 5U) || (op->args.size() == 6U)); + DataType dtype = op->args[0].dtype(); + Var buffer = Downcast(op->args[0]); + if (!IsAppropriateSharedMemory(buffer)) { + return StmtExprMutator::VisitExpr_(op); + } + PrimExpr extra_offset = GetBufferOffset(buffer, dtype); + PrimExpr offset = this->VisitExpr(op->args[1]); + if (op->args.size() == 5) + return Call( + op->dtype, op->op, + {merged_buf_var_, extra_offset + offset, op->args[2], op->args[3], op->args[4]}); + else + return Call(op->dtype, op->op, + {merged_buf_var_, extra_offset + offset, op->args[2], op->args[3], op->args[4], + op->args[5]}); } else { return StmtExprMutator::VisitExpr_(op); } diff --git a/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py b/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py index 8661843d39c1..ef6a6288cb38 100644 --- a/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py +++ b/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py @@ -513,5 +513,34 @@ def func(): return func +class TestAsyncCopy(tvm.testing.CompareBeforeAfter): + """Test async copy in shared memory.""" + + transform = tvm.tir.transform.MergeSharedMemoryAllocations() + + def before(self): + @T.prim_func + def func(A: T.buffer((128)), B: T.buffer((128))): + A_sh_data = T.allocate([128], "float32", "shared.dyn") + B_sh_data = T.allocate([128], "float32", "shared.dyn") + A_sh = T.buffer([128], data=A_sh_data, scope="shared.dyn") + B_sh = T.buffer([128], data=B_sh_data, scope="shared.dyn") + threadIdx_x = T.launch_thread("threadIdx.x", 128) + T.ptx_cp_async("float32", A_sh.data, threadIdx_x, A.data, threadIdx_x, 512) + T.ptx_cp_async("float32", B_sh.data, threadIdx_x, B.data, threadIdx_x, 512) + + return func + + def expected(self): + @T.prim_func + def func(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")): + threadIdx_x = T.launch_thread("threadIdx.x", 128) + buf_dyn_shmem = T.allocate([1024], "uint8", "shared.dyn") + T.ptx_cp_async("float32", buf_dyn_shmem, threadIdx_x, A.data, threadIdx_x, 512) + T.ptx_cp_async("float32", buf_dyn_shmem, 64 + threadIdx_x, B.data, threadIdx_x, 512) + + return func + + if __name__ == "__main__": tvm.testing.main() From 6b38db8a3d3ef9f36cbac6c1144af96896df4255 Mon Sep 17 00:00:00 2001 From: weitao <51255903105@stu.ecnu.edu.cn> Date: Wed, 27 Mar 2024 07:12:26 +0000 Subject: [PATCH 2/5] run black format --- ...t_tir_transform_merge_dynamic_shared_memory_allocations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py b/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py index ef6a6288cb38..46c61aaf3bf9 100644 --- a/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py +++ b/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py @@ -528,7 +528,7 @@ def func(A: T.buffer((128)), B: T.buffer((128))): threadIdx_x = T.launch_thread("threadIdx.x", 128) T.ptx_cp_async("float32", A_sh.data, threadIdx_x, A.data, threadIdx_x, 512) T.ptx_cp_async("float32", B_sh.data, threadIdx_x, B.data, threadIdx_x, 512) - + return func def expected(self): @@ -540,7 +540,7 @@ def func(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")): T.ptx_cp_async("float32", buf_dyn_shmem, 64 + threadIdx_x, B.data, threadIdx_x, 512) return func - + if __name__ == "__main__": tvm.testing.main() From edb8fa6fd75ad4d4d8e4947f5c56ce682f5568b8 Mon Sep 17 00:00:00 2001 From: weitao <51255903105@stu.ecnu.edu.cn> Date: Wed, 27 Mar 2024 08:38:49 +0000 Subject: [PATCH 3/5] fix get dtype of ptx_cp_async --- src/tir/transforms/merge_shared_memory_allocations.cc | 6 +++--- ...tir_transform_merge_dynamic_shared_memory_allocations.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/tir/transforms/merge_shared_memory_allocations.cc b/src/tir/transforms/merge_shared_memory_allocations.cc index 9eb3b5e62392..94c9b8f428b7 100644 --- a/src/tir/transforms/merge_shared_memory_allocations.cc +++ b/src/tir/transforms/merge_shared_memory_allocations.cc @@ -398,7 +398,7 @@ class SharedMemoryRewriter : public StmtExprMutator { {op->args[0], merged_buf_var_, extra_offset + offset, extent, op->args[4]}); } else if (op->op.same_as(builtin::ptx_cp_async())) { ICHECK((op->args.size() == 5U) || (op->args.size() == 6U)); - DataType dtype = op->args[0].dtype(); + DataType dtype = op->dtype; Var buffer = Downcast(op->args[0]); if (!IsAppropriateSharedMemory(buffer)) { return StmtExprMutator::VisitExpr_(op); @@ -407,10 +407,10 @@ class SharedMemoryRewriter : public StmtExprMutator { PrimExpr offset = this->VisitExpr(op->args[1]); if (op->args.size() == 5) return Call( - op->dtype, op->op, + dtype, op->op, {merged_buf_var_, extra_offset + offset, op->args[2], op->args[3], op->args[4]}); else - return Call(op->dtype, op->op, + return Call(dtype, op->op, {merged_buf_var_, extra_offset + offset, op->args[2], op->args[3], op->args[4], op->args[5]}); } else { diff --git a/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py b/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py index 46c61aaf3bf9..b4a9dbb98b8d 100644 --- a/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py +++ b/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py @@ -537,7 +537,7 @@ def func(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")): threadIdx_x = T.launch_thread("threadIdx.x", 128) buf_dyn_shmem = T.allocate([1024], "uint8", "shared.dyn") T.ptx_cp_async("float32", buf_dyn_shmem, threadIdx_x, A.data, threadIdx_x, 512) - T.ptx_cp_async("float32", buf_dyn_shmem, 64 + threadIdx_x, B.data, threadIdx_x, 512) + T.ptx_cp_async("float32", buf_dyn_shmem, 128 + threadIdx_x, B.data, threadIdx_x, 512) return func From 94e3f1de9bedecc8746b8c899a5829ea8bfd3a57 Mon Sep 17 00:00:00 2001 From: weitao <51255903105@stu.ecnu.edu.cn> Date: Wed, 27 Mar 2024 10:48:47 +0000 Subject: [PATCH 4/5] get correct offset of ptx_cp_async --- .../transforms/merge_shared_memory_allocations.cc | 15 ++++++++++----- ...orm_merge_dynamic_shared_memory_allocations.py | 4 ++-- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/src/tir/transforms/merge_shared_memory_allocations.cc b/src/tir/transforms/merge_shared_memory_allocations.cc index 94c9b8f428b7..bd9ff371517f 100644 --- a/src/tir/transforms/merge_shared_memory_allocations.cc +++ b/src/tir/transforms/merge_shared_memory_allocations.cc @@ -25,6 +25,7 @@ */ #include #include +#include #include #include @@ -405,14 +406,18 @@ class SharedMemoryRewriter : public StmtExprMutator { } PrimExpr extra_offset = GetBufferOffset(buffer, dtype); PrimExpr offset = this->VisitExpr(op->args[1]); + // the dst shared memory is a byte buffer generated by merging shared memory. + // we need to multiply the offset index by the byte size of the original value dtype, to get + // the correct offset of merged shared buffer. + int index_factor = dtype.bytes(); if (op->args.size() == 5) - return Call( - dtype, op->op, - {merged_buf_var_, extra_offset + offset, op->args[2], op->args[3], op->args[4]}); + return Call(dtype, op->op, + {merged_buf_var_, mul(extra_offset + offset, PrimExpr(index_factor)), + op->args[2], op->args[3], op->args[4]}); else return Call(dtype, op->op, - {merged_buf_var_, extra_offset + offset, op->args[2], op->args[3], op->args[4], - op->args[5]}); + {merged_buf_var_, mul(extra_offset + offset, PrimExpr(index_factor)), + op->args[2], op->args[3], op->args[4], op->args[5]}); } else { return StmtExprMutator::VisitExpr_(op); } diff --git a/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py b/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py index b4a9dbb98b8d..48c7737e14cc 100644 --- a/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py +++ b/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py @@ -536,8 +536,8 @@ def expected(self): def func(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")): threadIdx_x = T.launch_thread("threadIdx.x", 128) buf_dyn_shmem = T.allocate([1024], "uint8", "shared.dyn") - T.ptx_cp_async("float32", buf_dyn_shmem, threadIdx_x, A.data, threadIdx_x, 512) - T.ptx_cp_async("float32", buf_dyn_shmem, 128 + threadIdx_x, B.data, threadIdx_x, 512) + T.ptx_cp_async("float32", buf_dyn_shmem, threadIdx_x * 4, A.data, threadIdx_x, 512) + T.ptx_cp_async("float32", buf_dyn_shmem, (128 + threadIdx_x) * 4, B.data, threadIdx_x, 512) return func From 04c9e507144fac95a78e43040a5becd0ee2aacb1 Mon Sep 17 00:00:00 2001 From: weitao <51255903105@stu.ecnu.edu.cn> Date: Thu, 28 Mar 2024 02:00:59 +0000 Subject: [PATCH 5/5] black format --- ...t_tir_transform_merge_dynamic_shared_memory_allocations.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py b/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py index 48c7737e14cc..9bb0aaf6e8e8 100644 --- a/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py +++ b/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py @@ -537,7 +537,9 @@ def func(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")): threadIdx_x = T.launch_thread("threadIdx.x", 128) buf_dyn_shmem = T.allocate([1024], "uint8", "shared.dyn") T.ptx_cp_async("float32", buf_dyn_shmem, threadIdx_x * 4, A.data, threadIdx_x, 512) - T.ptx_cp_async("float32", buf_dyn_shmem, (128 + threadIdx_x) * 4, B.data, threadIdx_x, 512) + T.ptx_cp_async( + "float32", buf_dyn_shmem, (128 + threadIdx_x) * 4, B.data, threadIdx_x, 512 + ) return func