Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions src/tir/transforms/merge_shared_memory_allocations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

Expand Down Expand Up @@ -170,6 +171,7 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
StmtExprVisitor::VisitExpr_(op);
}
}

void VisitExpr_(const VarNode* buf) final {
// A direct reference to the variable counts as a read.
auto it = alloc_info_.find(buf);
Expand All @@ -180,6 +182,7 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
}
}
}

template <typename T>
void VisitNewScope(const T* op) {
scope_.push_back(StmtEntry());
Expand All @@ -200,6 +203,7 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
ICHECK_NE(end_index, 0U);
linear_seq_[begin_index].scope_pair_offset = end_index - begin_index;
}

void VisitStmt_(const AttrStmtNode* op) final {
// Only record the outermost thread extent.
if (op->attr_key == attr::thread_extent && !in_thread_env_) {
Expand All @@ -214,6 +218,7 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
StmtExprVisitor::VisitStmt_(op);
}
}

void VisitStmt_(const IfThenElseNode* op) final { VisitNewScope(op); }

void VisitStmt_(const ForNode* op) final { VisitNewScope(op); }
Expand Down Expand Up @@ -392,6 +397,27 @@ class SharedMemoryRewriter : public StmtExprMutator {
PrimExpr extent = this->VisitExpr(op->args[3]);
return Call(op->dtype, op->op,
{op->args[0], merged_buf_var_, extra_offset + offset, extent, op->args[4]});
} else if (op->op.same_as(builtin::ptx_cp_async())) {
ICHECK((op->args.size() == 5U) || (op->args.size() == 6U));
DataType dtype = op->dtype;
Var buffer = Downcast<Var>(op->args[0]);
if (!IsAppropriateSharedMemory(buffer)) {
return StmtExprMutator::VisitExpr_(op);
}
PrimExpr extra_offset = GetBufferOffset(buffer, dtype);
PrimExpr offset = this->VisitExpr(op->args[1]);
// the dst shared memory is a byte buffer generated by merging shared memory.
// we need to multiply the offset index by the byte size of the original value dtype, to get
// the correct offset of merged shared buffer.
int index_factor = dtype.bytes();
if (op->args.size() == 5)
return Call(dtype, op->op,
{merged_buf_var_, mul(extra_offset + offset, PrimExpr(index_factor)),
op->args[2], op->args[3], op->args[4]});
else
return Call(dtype, op->op,
{merged_buf_var_, mul(extra_offset + offset, PrimExpr(index_factor)),
op->args[2], op->args[3], op->args[4], op->args[5]});
} else {
return StmtExprMutator::VisitExpr_(op);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -513,5 +513,36 @@ def func():
return func


class TestAsyncCopy(tvm.testing.CompareBeforeAfter):
    """Test async copy in shared memory.

    Applies ``MergeSharedMemoryAllocations`` to a kernel that issues two
    ``T.ptx_cp_async`` copies into two separate "shared.dyn" allocations and
    compares the result with ``expected``, in which both copies write into a
    single merged "uint8" buffer.
    """

    # The pass under test.
    transform = tvm.tir.transform.MergeSharedMemoryAllocations()

    def before(self):
        """Return the input PrimFunc with two separate "shared.dyn" allocations."""

        @T.prim_func
        def func(A: T.buffer((128)), B: T.buffer((128))):
            # Two independent 128-element float32 allocations in "shared.dyn" scope.
            A_sh_data = T.allocate([128], "float32", "shared.dyn")
            B_sh_data = T.allocate([128], "float32", "shared.dyn")
            A_sh = T.buffer([128], data=A_sh_data, scope="shared.dyn")
            B_sh = T.buffer([128], data=B_sh_data, scope="shared.dyn")
            threadIdx_x = T.launch_thread("threadIdx.x", 128)
            # Each copy uses threadIdx_x as both the destination and the source offset.
            T.ptx_cp_async("float32", A_sh.data, threadIdx_x, A.data, threadIdx_x, 512)
            T.ptx_cp_async("float32", B_sh.data, threadIdx_x, B.data, threadIdx_x, 512)

        return func

    def expected(self):
        """Return the PrimFunc expected after the two allocations are merged."""

        @T.prim_func
        def func(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")):
            threadIdx_x = T.launch_thread("threadIdx.x", 128)
            # One merged allocation: 1024 uint8 == 2 buffers * 128 float32 elements * 4 bytes.
            buf_dyn_shmem = T.allocate([1024], "uint8", "shared.dyn")
            # Destination offsets are scaled by 4 because the merged buffer is indexed in
            # bytes; the second copy is additionally shifted by 128 elements.
            T.ptx_cp_async("float32", buf_dyn_shmem, threadIdx_x * 4, A.data, threadIdx_x, 512)
            T.ptx_cp_async(
                "float32", buf_dyn_shmem, (128 + threadIdx_x) * 4, B.data, threadIdx_x, 512
            )

        return func


# Invoke tvm.testing.main() when this file is executed as a script rather than imported.
if __name__ == "__main__":
    tvm.testing.main()