Commit 5b2d486 (1 parent: c3cb170)

16x8x16 4k tune working

2 files changed: +19 −19 lines

src/tir/transforms/lower_warp_memory.cc (1 addition, 2 deletions)
@@ -118,7 +118,6 @@ class WarpStoreCoeffFinder : private StmtExprVisitor {
       int num_matrix = op->args[1].as<IntImmNode>()->value;
       warp_coeff_ = num_matrix * 2;
     } else if (op->op.same_as(builtin::mma_fill()) && op->args[1].as<VarNode>() == buffer_) {
-      LOG(INFO) << op->args[0];
       auto* ptr = op->args[0].as<IntImmNode>();
       CHECK(ptr);
       warp_coeff_ = ptr->value;
@@ -500,7 +499,7 @@ Pass LowerWarpMemory() {
     WarpMemoryRewriter warp_memory_rewriter(warp_size);
     auto stmt = warp_memory_rewriter.Rewrite(std::move(n->body));
     n->body = UpdatePointerStorageScope(warp_memory_rewriter.new_storage_scopes_)(stmt);
-    LOG(INFO) << f;
+    // LOG(INFO) << f;
     return f;
   };
   return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {});

tests/python/unittest/test_mma_16x8x16_4k_tune.py (18 additions, 17 deletions)
@@ -53,9 +53,9 @@ def ldmatrix_a_impl(a: T.handle, c: T.handle) -> None:
                     4,
                     ".b16",
                     A_warp.data,
-                    8 * tx,
-                    A_shared.data,
-                    16 * (tx % 16) + 8 * (tx // 16),
+                    A_warp.elem_offset + 8 * tx,
+                    A_shared.access_ptr("r"),
+                    s1 * (tx % 16) + 8 * (tx // 16),
                     dtype="float16",
                 )
             )
@@ -106,9 +106,9 @@ def ldmatrix_b_impl(a: T.handle, c: T.handle) -> None:
                     4,
                     ".b16",
                     B_warp.data,
-                    8 * tx,
-                    B_shared.data,
-                    16 * (tx % 16) + 8 * (tx // 16),
+                    B_warp.elem_offset + 8 * tx,
+                    B_shared.access_ptr("r"),
+                    s1 * (tx % 16) + 8 * (tx // 16),
                     dtype="float16",
                 )
             )
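Both ldmatrix hunks make the same fix: the destination offset is now taken relative to the warp buffer's elem_offset, and the shared operand is passed through access_ptr so the stride s1 bound by T.match_buffer reaches the generated code instead of a hard-coded 16. A minimal sketch of the resulting call, reusing the A-side names from the surrounding impl; the leading trans flag (0, no transpose) is not visible in this hunk and is an assumption:

    # Sketch only: A_warp, A_shared, tx, and s1 come from the enclosing
    # ldmatrix_a_impl; the leading trans flag is assumed.
    T.evaluate(
        T.ptx_ldmatrix(
            0,                                # trans flag (assumed: no transpose)
            4,                                # num: four 8x8 matrices per warp
            ".b16",                           # 16-bit element type
            A_warp.data,                      # destination warp-local buffer
            A_warp.elem_offset + 8 * tx,      # per-thread destination offset
            A_shared.access_ptr("r"),         # shared source as an access_ptr
            s1 * (tx % 16) + 8 * (tx // 16),  # source offset using the real stride s1
            dtype="float16",
        )
    )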
@@ -313,9 +313,10 @@ def schedule(sch: tir.Schedule):
         k_factors = sch.sample_perfect_tile(k, n=3)
         num_ty = sch.get(i_factors[2]) * sch.get(j_factors[2])
     else:
-        i_factors = [1, 16, 4, 2, 2]
-        j_factors = [1, 32, 1, 8, 1]
-        k_factors = [64, 4, 1]
+        i_factors = [4, 8, 2, 4, 1]
+        j_factors = [1, 64, 2, 1, 2]
+        k_factors = [128, 2, 1]
+
         num_ty = i_factors[2] * j_factors[2]

    i0, i1, i2, i3, i4 = sch.split(i, factors=i_factors)
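A quick arithmetic check of the retuned factors, assuming a 4096 extent per axis with 16-wide tensor-core tiles (suggested by the 4k file name, not stated in the hunk): each factor list still multiplies to 256, and num_ty is unchanged at 4.

    import math

    i_factors = [4, 8, 2, 4, 1]   # new values from this hunk
    j_factors = [1, 64, 2, 1, 2]
    k_factors = [128, 2, 1]

    # Each product is 256, so with 16-element tiles every axis still spans
    # 256 * 16 == 4096 (the 4k assumption above).
    assert math.prod(i_factors) == math.prod(j_factors) == math.prod(k_factors) == 256

    # num_ty is unchanged: old 4 * 1 == 4, new 2 * 2 == 4.
    print(i_factors[2] * j_factors[2])  # 4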
@@ -487,12 +488,12 @@ def shared_16x16_to_ldmatrix_32x8_layout(i, j):
     c = tvm.nd.array(np.zeros((M, N), dtype="float32"), dev)
     f = tvm.build(sch.mod["main"], target="cuda", name="dense")

-    # print(f.imported_modules[0].get_source())
-    # f(a, b, c)
-    # tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
-    # print("ok")
+    print(f.imported_modules[0].get_source())
+    f(a, b, c)
+    tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
+    print("ok")

-    # evaluator = f.time_evaluator(f.entry_name, dev, number=1000)
-    # gflops = (N * M * K) * 2 / 1e9
-    # time_ms = evaluator(a, b, c).mean * 1e3
-    # print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3)))
+    evaluator = f.time_evaluator(f.entry_name, dev, number=1000)
+    gflops = (N * M * K) * 2 / 1e9
+    time_ms = evaluator(a, b, c).mean * 1e3
+    print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3)))
