Skip to content

Commit ae06789

Browse files
committed
mma store codegen working
1 parent deb4d66 commit ae06789

File tree

3 files changed

+37
-24
lines changed

3 files changed

+37
-24
lines changed

src/target/source/codegen_cuda.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -821,6 +821,17 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
821821
std::string smem_elem_offset = this->PrintExpr(op->args[6]);
822822
this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr, local_elem_offset,
823823
smem_ptr, smem_elem_offset);
824+
} else if (op->op.same_as(builtin::mma_store())) {
825+
std::string dst = this->PrintExpr(op->args[1]);
826+
std::string src = this->PrintExpr(op->args[2]);
827+
std::string src_offset = this->PrintExpr(op->args[3]);
828+
std::string stride = this->PrintExpr(op->args[4]);
829+
830+
os << "for (int i = 0; i < 4; ++i) {\n";
831+
os << dst << "[(i / 2 * 8 + threadIdx.x / 4) * " << stride
832+
<< " + (threadIdx.x % 4) * 2 + i % 2]"
833+
<< " = " << src << "[" << src_offset << " + i];\n";
834+
os << "}\n";
824835
} else {
825836
CodeGenC::VisitExpr_(op, os);
826837
}

src/tir/transforms/lower_warp_memory.cc

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -282,13 +282,13 @@ class WarpAccessRewriter : protected StmtExprMutator {
282282
}
283283

284284
if (op->op.same_as(builtin::mma_store())) {
285-
// Array<PrimExpr> new_args = op->args;
286-
// PrimExpr local_index, group;
287-
// if (op->args[3].get() == buffer_) {
288-
// std::tie(local_index, group) = SplitIndexByGroup(op->args[4]);
289-
// new_args.Set(4, local_index);
290-
// return Call(op->dtype, op->op, new_args);
291-
// }
285+
Array<PrimExpr> new_args = op->args;
286+
PrimExpr local_offset, group;
287+
if (op->args[2].get() == buffer_) {
288+
std::tie(local_offset, group) = SplitIndexByGroup(op->args[3]);
289+
new_args.Set(3, local_offset);
290+
return Call(op->dtype, op->op, new_args);
291+
}
292292
return GetRef<PrimExpr>(op);
293293
}
294294

tests/python/unittest/test_mma_16x8x8_4k_tune.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -184,16 +184,19 @@ def mma_store_desc(a: T.handle, c: T.handle) -> None:
184184

185185
@T.prim_func
186186
def mma_store_impl(a: T.handle, c: T.handle) -> None:
187+
s1 = T.var("int32")
188+
s0 = T.var("int32")
189+
187190
C_warp = T.match_buffer(a, [32, 4], dtype="float32", scope="warp", offset_factor=1)
188-
C = T.match_buffer(c, [16, 8], dtype="float32", scope="global",offset_factor=1)
191+
C = T.match_buffer(c, [16, 8], dtype="float32", scope="global",offset_factor=1, strides=[s1, s0])
189192

190193
with T.block("root"):
191194
T.reads(C_warp[0:32, 0:4])
192195
T.writes(C[0:16, 0:8])
193196
tx = T.env_thread("threadIdx.x")
194197
T.launch_thread(tx, 32)
195198

196-
T.evaluate(T.mma_store("m16n8", C.data, C.elem_offset, C_warp.access_ptr("r"), tx, dtype="float32"))
199+
T.evaluate(T.mma_store("m16n8", C.access_ptr("w"), C_warp.data, C_warp.elem_offset, s1, dtype="float32"))
197200

198201

199202
tir.TensorIntrin.register("mma.ldmatrix_a", ldmatrix_a_desc, ldmatrix_a_impl)
@@ -388,7 +391,6 @@ def lambda_b(i, j):
388391
fused_2 = sch.fuse(f_0, f_3)
389392

390393
# print(sch.mod.script())
391-
392394
# return
393395

394396
sch.tensorize(fused_1, "mma_store")
@@ -423,20 +425,20 @@ def lambda_b(i, j):
423425
print(sch.mod.script())
424426
target = "cuda"
425427
f = tvm.build(sch.mod["main"], target=target, name="dense")
426-
print(f.imported_modules[0].get_source())
427-
428-
# dev = tvm.device("cuda", 0)
429-
# a_np = np.random.uniform(size=(N, K)).astype("float16")
430-
# b_np = np.random.uniform(size=(K, M)).astype("float16")
431-
# c_np = np.dot(a_np.astype("float32"), b_np.astype("float32"))
432-
# a = tvm.nd.array(a_np, dev)
433-
# b = tvm.nd.array(b_np, dev)
434-
# c = tvm.nd.array(np.zeros((M, N), dtype="float32"), dev)
435-
# f = tvm.build(sch.mod["main"], target="cuda", name="dense")
436-
437-
# f(a, b, c)
438-
# tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
439-
# print("ok")
428+
429+
dev = tvm.device("cuda", 0)
430+
a_np = np.random.uniform(size=(N, K)).astype("float16")
431+
b_np = np.random.uniform(size=(K, M)).astype("float16")
432+
c_np = np.dot(a_np.astype("float32"), b_np.astype("float32"))
433+
a = tvm.nd.array(a_np, dev)
434+
b = tvm.nd.array(b_np, dev)
435+
c = tvm.nd.array(np.zeros((M, N), dtype="float32"), dev)
436+
f = tvm.build(sch.mod["main"], target="cuda", name="dense")
437+
438+
print(f.imported_modules[0].get_source())
439+
f(a, b, c)
440+
tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
441+
print("ok")
440442

441443
# evaluator = f.time_evaluator(f.entry_name, dev, number=1000)
442444
# gflops = (N * M * K) * 2 / 1e9

0 commit comments

Comments (0)