
Commit c3cb170

tensorize fixed
1 parent 68039b0 commit c3cb170

1 file changed: 81 additions, 83 deletions


tests/python/unittest/test_mma_16x8x16_4k_tune.py

@@ -127,17 +127,19 @@ def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
             with T.block("C"):
                 i, j, k = T.axis.remap("SSR", [i, j, k])
                 T.reads(
-                    C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2],
-                    A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2],
-                    B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2],
+                    C[i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 2],
+                    A[i % 8 * 4 + k % 8 // 2, k % 16 // 8 * 4 + i % 16 // 8 * 2 + k % 2],
+                    B[k % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + k % 16 // 8 * 2 + j % 2],
                 )
-                T.writes(C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2])
-                C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] = C[
-                    i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2
+                T.writes(C[i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 2])
+                C[i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 8 % 2] = C[
+                    i % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 8 % 2
                 ] + T.cast(
-                    A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2], "float32"
+                    A[i % 8 * 4 + k % 8 // 2, k % 16 // 8 * 4 + i % 16 // 8 * 2 + k % 8 % 2],
+                    "float32",
                 ) * T.cast(
-                    B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2], "float32"
+                    B[k % 8 * 4 + j % 8 // 2, j % 16 // 8 * 4 + k % 16 // 8 * 2 + j % 8 % 2],
+                    "float32",
                 )
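Note: the new % 16 terms make the hand-expanded indices in this descriptor agree with the shared_16x16_to_ldmatrix_32x8_layout helper this file defines further down. A standalone sanity sketch (plain Python, not part of the commit; the helper body is inferred from the expressions above) checking that the map sends the 16x16 tile bijectively onto 32 threads x 8 lanes:

# Sanity sketch (not from the commit): the index map the expanded
# expressions above implement, written as the helper defined later in the file.
def shared_16x16_to_ldmatrix_32x8_layout(i, j):
    thread_id = i % 8 * 4 + j % 8 // 2
    local_id = j % 16 // 8 * 4 + i % 16 // 8 * 2 + j % 8 % 2
    return thread_id, local_id

# Every (i, j) of the 16x16 tile must land on a distinct (thread, lane) slot.
slots = {shared_16x16_to_ldmatrix_32x8_layout(i, j) for i in range(16) for j in range(16)}
assert len(slots) == 256
assert all(0 <= t < 32 and 0 <= l < 8 for t, l in slots)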

@@ -242,11 +244,19 @@ def mma_fill_desc(a: T.handle) -> None:
         T.writes(C_warp[0:32, 0:8])
         for i0, i1 in T.grid(32, 8):
             with T.block("C_warp"):
-                i = T.axis.spatial(16, i1 // 4 * 8 + i0 // 4)
-                j = T.axis.spatial(16, (i0 % 4) * 4 + i1 % 4)
+                i_init = T.axis.spatial(16, i1 // 4 * 8 + i0 // 4)
+                j_init = T.axis.spatial(16, (i0 % 4) * 4 + i1 % 4)
                 T.reads()
-                T.writes(C_warp[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2])
-                C_warp[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] = T.float32(0)
+                T.writes(
+                    C_warp[
+                        i_init % 8 * 4 + j_init % 8 // 2,
+                        j_init % 16 // 8 * 4 + i_init % 16 // 8 * 2 + j_init % 2,
+                    ]
+                )
+                C_warp[
+                    i_init % 8 * 4 + j_init % 8 // 2,
+                    j_init % 16 // 8 * 4 + i_init % 16 // 8 * 2 + j_init % 8 % 2,
+                ] = T.float32(0)


 @T.prim_func
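The fill descriptor renames its block vars to i_init/j_init and writes through the same warp layout. A quick coverage check (plain Python, reusing the helper from the sketch above; not part of the commit) that the (i0, i1) -> (i_init, j_init) remap composed with the layout hits each of the 32 x 8 fragment slots exactly once:

covered = set()
for i0 in range(32):
    for i1 in range(8):
        i_init = i1 // 4 * 8 + i0 // 4       # as in T.axis.spatial above
        j_init = (i0 % 4) * 4 + i1 % 4
        covered.add(shared_16x16_to_ldmatrix_32x8_layout(i_init, j_init))
assert len(covered) == 256                   # all 32 threads x 8 lanes, once each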
@@ -304,8 +314,8 @@ def schedule(sch: tir.Schedule):
         num_ty = sch.get(i_factors[2]) * sch.get(j_factors[2])
     else:
         i_factors = [1, 16, 4, 2, 2]
-        j_factors = [1, 64, 1, 8, 1]
-        k_factors = [128, 4, 1]
+        j_factors = [1, 32, 1, 8, 1]
+        k_factors = [64, 4, 1]
         num_ty = i_factors[2] * j_factors[2]

     i0, i1, i2, i3, i4 = sch.split(i, factors=i_factors)
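The halved factors are consistent with tiling the whole schedule in 16x16 fragments; the old products (512) appear to date from an earlier 8-wide tiling (compare the [None, 8] -> [None, 16] split in the next hunk). Quick arithmetic, under the assumption that the "4k" in the test name means M = N = K = 4096 (the workload definition is outside this diff):

# Assumption: M = N = K = 4096, 16x16 tiles, so each outer loop has extent 256.
M = N = K = 4096
assert 1 * 16 * 4 * 2 * 2 == M // 16   # i_factors, unchanged
assert 1 * 32 * 1 * 8 * 1 == N // 16   # j_factors, previously 1*64*1*8*1 = 512
assert 64 * 4 * 1 == K // 16           # k_factors, previously 128*4*1 = 512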
@@ -368,7 +378,7 @@ def fetch_to_shared(block, idx, ndim):

     ii, jj = sch.get_loops(C_warp)[-2:]
     io, ii = sch.split(ii, factors=[None, 16])
-    jo, ji = sch.split(jj, factors=[None, 8])
+    jo, ji = sch.split(jj, factors=[None, 16])
     sch.reorder(io, jo, ii, ji)

     block_init_c = sch.decompose_reduction(block_outer, sch.get_loops(block_outer)[3])
@@ -394,18 +404,10 @@ def shared_16x16_to_ldmatrix_32x8_layout(i, j):
     loop_b = tile_wmma_fragment(B_warp, 16)

     sch.transform_layout(A_warp, 0, "write", index_map=shared_16x16_to_ldmatrix_32x8_layout)
-    sch.transform_layout(
-        B_warp,
-        0,
-        "write",
-        index_map=shared_16x16_to_ldmatrix_32x8_layout
-    )
-    sch.transform_layout(
-        C_warp,
-        0,
-        "read",
-        index_map=shared_16x16_to_ldmatrix_32x8_layout
-    )
+    sch.transform_layout(B_warp, 0, "write", index_map=shared_16x16_to_ldmatrix_32x8_layout)
+    sch.transform_layout(C_warp, 0, "read", index_map=shared_16x16_to_ldmatrix_32x8_layout)
+
+    # return

     if use_ldmatrix:
         sch.tensorize(loop_a, "mma.ldmatrix_a")
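With the layout transforms collapsed to one call per fragment, all three warp-scope buffers are physically addressed as [thread_id, local_id]: A and B on their write side (they are produced by the ldmatrix loads) and C on its read side (it is consumed by the store), so the mma.ldmatrix_a / mma.ldmatrix_b and mma_sync tensorizations below can match against one consistent 32x8 layout.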
@@ -425,69 +427,65 @@ def shared_16x16_to_ldmatrix_32x8_layout(i, j):
     fused_1 = sch.fuse(warp_loop2, f_0)
     sch.bind(fused_1, "threadIdx.x")

-    # mma_loop = sch.get_loops(block_inner)[-3]
-    # sch.tensorize(mma_loop, "mma_sync")
-
-    # block_init_c = sch.get_block("C_init")
-    # init_loop1, init_loop2 = sch.get_loops(block_init_c)[-2:]
-    # f_0, f_1 = sch.split(init_loop1, factors=[None, 8])
-    # f_2, f_3 = sch.split(init_loop2, factors=[None, 2])
-    # sch.reorder(f_1, f_2, f_0, f_3)
-    # fused_1 = sch.fuse(f_1, f_2)
-    # fused_2 = sch.fuse(f_0, f_3)
-    # # sch.bind(fused_1, "threadIdx.x")
-    # sch.tensorize(fused_1, "mma_fill")
-
-    # warp_loop1, warp_loop2 = sch.get_loops(C_warp)[-2:]
-    # f_0, f_1 = sch.split(warp_loop1, factors=[None, 8])
-    # f_2, f_3 = sch.split(warp_loop2, factors=[None, 2])
-    # sch.reorder(f_1, f_2, f_0, f_3)
-    # fused_1 = sch.fuse(f_1, f_2)
-    # fused_2 = sch.fuse(f_0, f_3)
-
-    # # print(sch.mod.script())
-    # # return
-
-    # sch.tensorize(fused_1, "mma_store")
+    mma_loop = sch.get_loops(block_inner)[-3]
+    sch.tensorize(mma_loop, "mma_sync")
+
+    block_init_c = sch.get_block("C_init")
+    init_loop1, init_loop2 = sch.get_loops(block_init_c)[-2:]
+    f_0, f_1 = sch.split(init_loop1, factors=[None, 8])
+    f_2, f_3 = sch.split(init_loop2, factors=[None, 4])
+    sch.reorder(f_1, f_2, f_0, f_3)
+    fused_1 = sch.fuse(f_1, f_2)
+    fused_2 = sch.fuse(f_0, f_3)
+    sch.tensorize(fused_1, "mma_fill")
+
+    warp_loop1, warp_loop2 = sch.get_loops(C_warp)[-2:]
+    f_0, f_1 = sch.split(warp_loop1, factors=[None, 8])
+    outer, f_2, f_3 = sch.split(warp_loop2, factors=[2, 4, 2])
+    sch.reorder(outer, f_1, f_2, f_0, f_3)
+    fused_1 = sch.fuse(f_1, f_2)
+    fused_2 = sch.fuse(f_0, f_3)
+    sch.tensorize(outer, "mma_store")
+    # print(sch.mod.script())
+    # return


 ir_module = tvm.IRModule({"main": workload})
 sch = tvm.tir.Schedule(ir_module)
 schedule(sch)
 print(sch.mod.script())

-# if tune:
-#     with tempfile.TemporaryDirectory() as work_dir:
-#         sch = ms.tune_tir(
-#             mod=workload,
-#             target=tvm.target.Target("nvidia/geforce-rtx-3070"),
-#             config=ms.TuneConfig(
-#                 strategy="evolutionary",
-#                 num_trials_per_iter=32,
-#                 max_trials_per_task=128,
-#                 max_trials_global=128,
-#             ),
-#             work_dir=work_dir,
-#             space=ms.space_generator.ScheduleFn(schedule),
-#         )
-#     if sch is None:
-#         print("No valid schedule found!")
-#     else:
-#         print(sch.mod.script())
-#         print(sch.trace)
-# else:
-#     print(sch.mod.script())
-# target = "cuda"
-# f = tvm.build(sch.mod["main"], target=target, name="dense")
-
-# dev = tvm.device("cuda", 0)
-# a_np = np.random.uniform(size=(N, K)).astype("float16")
-# b_np = np.random.uniform(size=(K, M)).astype("float16")
-# c_np = np.dot(a_np.astype("float32"), b_np.astype("float32"))
-# a = tvm.nd.array(a_np, dev)
-# b = tvm.nd.array(b_np, dev)
-# c = tvm.nd.array(np.zeros((M, N), dtype="float32"), dev)
-# f = tvm.build(sch.mod["main"], target="cuda", name="dense")
+if tune:
+    with tempfile.TemporaryDirectory() as work_dir:
+        sch = ms.tune_tir(
+            mod=workload,
+            target=tvm.target.Target("nvidia/geforce-rtx-3070"),
+            config=ms.TuneConfig(
+                strategy="evolutionary",
+                num_trials_per_iter=32,
+                max_trials_per_task=128,
+                max_trials_global=128,
+            ),
+            work_dir=work_dir,
+            space=ms.space_generator.ScheduleFn(schedule),
+        )
+    if sch is None:
+        print("No valid schedule found!")
+    else:
+        print(sch.mod.script())
+        print(sch.trace)
+else:
+    target = "cuda"
+    f = tvm.build(sch.mod["main"], target=target, name="dense")
+
+    dev = tvm.device("cuda", 0)
+    a_np = np.random.uniform(size=(N, K)).astype("float16")
+    b_np = np.random.uniform(size=(K, M)).astype("float16")
+    c_np = np.dot(a_np.astype("float32"), b_np.astype("float32"))
+    a = tvm.nd.array(a_np, dev)
+    b = tvm.nd.array(b_np, dev)
+    c = tvm.nd.array(np.zeros((M, N), dtype="float32"), dev)
+    f = tvm.build(sch.mod["main"], target="cuda", name="dense")

 # print(f.imported_modules[0].get_source())
 # f(a, b, c)
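The store path is the subtle part of this hunk: warp_loop2 (extent 16) is now split with factors [2, 4, 2] and the tensorized region starts at outer. A sketch of why those factors are the right ones (plain Python, reusing the layout helper from the first sketch; not part of the commit): after the reorder, the loop variables reproduce exactly the warp layout's thread and lane indices.

for outer in range(2):                 # j // 8: which 16x8 half of the tile
    for f_1 in range(8):               # i % 8
        for f_2 in range(4):           # (j % 8) // 2
            for f_0 in range(2):       # i // 8
                for f_3 in range(2):   # j % 2
                    i = f_0 * 8 + f_1
                    j = outer * 8 + f_2 * 2 + f_3
                    t, l = shared_16x16_to_ldmatrix_32x8_layout(i, j)
                    assert t == f_1 * 4 + f_2               # fused_1 = fuse(f_1, f_2)
                    assert l == outer * 4 + f_0 * 2 + f_3   # outer, fused_2 = fuse(f_0, f_3)

After the fuses the nest is outer(2) x fused_1(32) x fused_2(4): fused_1 enumerates the 32 lanes and (outer, fused_2) addresses each thread's 8 accumulator values, which is the shape the mma_store description can pattern-match. Since the kernel launch itself stays commented out (# f(a, b, c)), one way to actually run and check the build against the numpy reference computed above would be (a sketch, assuming import tvm.testing; not in the commit):

f(a, b, c)
tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)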
