
Commit e80a1f1

clean up
1 parent a9640f4

1 file changed: +6 −11 lines


tests/python/unittest/test_mma_16x8x8_4k_tune.py

Lines changed: 6 additions & 11 deletions
@@ -183,14 +183,14 @@ def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
 tir.TensorIntrin.register("mma.ldmatrix_b", ldmatrix_b_desc, ldmatrix_b_impl)
 tir.TensorIntrin.register("mma_sync", mma_sync_desc, mma_sync_impl)
 
-
 N = 4096
 M = 4096
 K = 4096
 
 workload = te.create_prim_func(te_workload.matmul_fp16(n=N, m=M, k=K))
 
 tune = False
+use_ldmatrix = True
 
 
 def schedule(sch: tir.Schedule):
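For context, te_workload.matmul_fp16 here comes from tvm.meta_schedule.testing and builds the 4096x4096x4096 GEMM being scheduled. A hedged sketch of an equivalent workload; the float16 inputs with float32 accumulation are an assumption about that helper, not something visible in this diff:

from tvm import te

n = m = k = 4096
A = te.placeholder((n, k), dtype="float16", name="A")
B = te.placeholder((k, m), dtype="float16", name="B")
kk = te.reduce_axis((0, k), name="k")
# float16 operands, float32 accumulator (assumed dtype policy)
C = te.compute(
    (n, m),
    lambda i, j: te.sum(A[i, kk].astype("float32") * B[kk, j].astype("float32"), axis=kk),
    name="C",
)
workload_sketch = te.create_prim_func([A, B, C])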
@@ -199,6 +199,7 @@ def schedule(sch: tir.Schedule):
     i, i_tc = sch.split(i, factors=[None, 16])
     j, j_tc = sch.split(j, factors=[None, 8])
     k, k_tc = sch.split(k, factors=[None, 8])
+
     sch.reorder(
         i, j, k,
         i_tc, j_tc, k_tc,
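The added blank line just separates the splitting phase from the reorder. For readers new to TIR scheduling, a minimal standalone sketch (not from the commit) of the split/reorder tiling idiom used here; factors=[None, 16] lets TVM infer the outer extent:

from tvm import te, tir

A = te.placeholder((4096, 4096), name="A")
B = te.compute((4096, 4096), lambda i, j: A[i, j] + 1.0, name="B")
sch = tir.Schedule(te.create_prim_func([A, B]))
i, j = sch.get_loops(sch.get_block("B"))
i_outer, i_tc = sch.split(i, factors=[None, 16])  # outer extent inferred: 4096 // 16
j_outer, j_tc = sch.split(j, factors=[None, 8])   # outer extent inferred: 4096 // 8
sch.reorder(i_outer, j_outer, i_tc, j_tc)  # the 16x8 tile loops become innermost
print(sch.mod.script())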
@@ -211,10 +212,12 @@ def schedule(sch: tir.Schedule):
         i_factors = sch.sample_perfect_tile(i, n=5)
         j_factors = sch.sample_perfect_tile(j, n=5)
         k_factors = sch.sample_perfect_tile(k, n=3)
+        num_ty = sch.get(i_factors[2]) * sch.get(j_factors[2])
     else:
         i_factors = [1, 16, 4, 2, 2]
         j_factors = [1, 64, 1, 8, 1]
         k_factors = [128, 4, 1]
+        num_ty = i_factors[2] * j_factors[2]
 
     i0, i1, i2, i3, i4 = sch.split(i, factors=i_factors)
     j0, j1, j2, j3, j4 = sch.split(j, factors=j_factors)
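Hoisting num_ty into each branch is what makes the isinstance dispatch deleted in the next hunk unnecessary: sample_perfect_tile returns ExprRV handles whose concrete values come from sch.get(), while the hand-picked factors are plain ints. A standalone sketch of that distinction (a 1-D loop, chosen for brevity):

from tvm import te, tir

A = te.placeholder((4096,), name="A")
B = te.compute((4096,), lambda i: A[i] * 2.0, name="B")
sch = tir.Schedule(te.create_prim_func([A, B]))
(i,) = sch.get_loops(sch.get_block("B"))
factors = sch.sample_perfect_tile(i, n=5)  # ExprRVs, not plain ints
values = [sch.get(f) for f in factors]     # concrete sampled values
print(values)  # e.g. [4, 8, 16, 2, 4]; the product is always 4096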
@@ -241,11 +244,6 @@ def schedule(sch: tir.Schedule):
     sch.bind(block_idy, "blockIdx.y")
     sch.bind(thread_idy, "threadIdx.y")
 
-    if isinstance(i_factors[2], int):
-        num_ty = i_factors[2] * j_factors[2]
-    else:
-        num_ty = sch.get(i_factors[2]) * sch.get(j_factors[2])
-
     def fetch_to_shared(block, idx, ndim):
         block_read = sch.cache_read(block, idx, "shared")
         sch.compute_at(block_read, k0)
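fetch_to_shared, whose first lines are context above, pairs cache_read with compute_at: stage an operand into shared memory, then anchor the copy under the outer reduction loop k0. A minimal standalone sketch of the pairing; the buffer index (here 0) selects which of the block's read buffers to stage:

from tvm import te, tir

A = te.placeholder((128, 128), name="A")
B = te.placeholder((128, 128), name="B")
C = te.compute((128, 128), lambda i, j: A[i, j] + B[i, j], name="C")
sch = tir.Schedule(te.create_prim_func([A, B, C]))
block = sch.get_block("C")
i, j = sch.get_loops(block)
A_shared = sch.cache_read(block, 0, "shared")  # stage A into shared scope
sch.compute_at(A_shared, i)  # copy A's slice once per iteration of i
print(sch.mod.script())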
@@ -327,8 +325,6 @@ def lambda_b(i, j):
         index_map=lambda_a,
     )
 
-    use_ldmatrix = True
-
     if use_ldmatrix:
         sch.tensorize(loop_a, "mma.ldmatrix_a")
         sch.tensorize(loop_b, "mma.ldmatrix_b")
@@ -347,8 +343,8 @@ def lambda_b(i, j):
     fused_1 = sch.fuse(warp_loop2, f_0)
     sch.bind(fused_1, "threadIdx.x")
 
-    loop = sch.get_loops(block_inner)[-3]
-    sch.tensorize(loop, "mma_sync")
+    mma_loop = sch.get_loops(block_inner)[-3]
+    sch.tensorize(mma_loop, "mma_sync")
 
     block_init_c = sch.get_block("C_init")
     init_loop1, init_loop2 = sch.get_loops(block_init_c)[-2:]
@@ -378,7 +374,6 @@ def lambda_b(i, j):
     sch = ms.tune_tir(
         mod=workload,
         target=tvm.target.Target("nvidia/geforce-rtx-3070"),
-        # use replay or evolutionary search
         config=ms.TuneConfig(
             strategy="evolutionary",
             num_trials_per_iter=32,
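The hunk is truncated inside TuneConfig. A hedged sketch of how a tune_tir call from this era of MetaSchedule is typically completed; the trial budgets and work_dir below are illustrative assumptions, not values from this file:

import tvm
from tvm import meta_schedule as ms

sch = ms.tune_tir(
    mod=workload,  # the te_workload matmul defined earlier in this file
    target=tvm.target.Target("nvidia/geforce-rtx-3070"),
    config=ms.TuneConfig(
        strategy="evolutionary",
        num_trials_per_iter=32,
        max_trials_per_task=1000,  # assumed budget
        max_trials_global=1000,    # assumed budget
    ),
    work_dir="./tune_logs",  # assumed log directory
)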
