
Commit 94d9d96

int8 4k tune working
1 parent 3ca8ca0 commit 94d9d96

File tree

1 file changed: +15 -16 lines


tests/python/unittest/test_mma_16x8x32_4k_tune.py

Lines changed: 15 additions & 16 deletions
@@ -53,8 +53,8 @@ def ldmatrix_a_impl(a: T.handle, c: T.handle) -> None:
                 4,
                 ".b16",
                 A_warp.data,
-                16 * tx,
-                A_shared.data,
+                A_warp.elem_offset + 16 * tx,
+                A_shared.access_ptr("r"),
                 s1 * (tx % 16) + 16 * (tx // 16),
                 dtype="int8",
             )
@@ -104,8 +104,8 @@ def ldmatrix_b_impl(a: T.handle, c: T.handle) -> None:
                 4,
                 ".b16",
                 B_warp.data,
-                16 * tx,
-                B_shared.data,
+                B_warp.elem_offset + 16 * tx,
+                B_shared.access_ptr("r"),
                 s1,
                 dtype="int8",
             )
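
Note (added for context, not part of the commit): both ldmatrix wrappers now pass the warp buffer's elem_offset plus the per-lane offset, and hand the shared buffer to the intrinsic through access_ptr("r") instead of the raw data var. The per-lane shared-memory offset expression itself is unchanged; the small standalone sketch below (my illustration, assuming s1 is the row stride of the shared tile in elements and tx the lane id) just spells out which address each of the 32 lanes supplies:

# Standalone sketch of the per-lane shared-memory offsets in ldmatrix_a_impl.
# Assumptions: s1 is the row stride (in elements) of the shared tile, tx is the lane id.
s1 = 32  # example stride for an unpadded 16x32 int8 tile
for tx in range(32):
    offset = s1 * (tx % 16) + 16 * (tx // 16)
    row, col = offset // s1, offset % s1
    # Lanes 0-15 point at column 0 of rows 0-15, lanes 16-31 at column 16 of the same rows:
    # one 16-byte row segment per lane, matching the per-lane addresses ldmatrix .x4 consumes.
    print(f"lane {tx:2d} -> row {row:2d}, col {col:2d}")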
@@ -359,7 +359,7 @@ def schedule(sch: tir.Schedule):
     sch.bind(block_idy, "blockIdx.y")
     sch.bind(thread_idy, "threadIdx.y")

-    def fetch_to_shared(block, idx, ndim):
+    def fetch_to_shared(block, idx, ndim, vec=False):
         block_read = sch.cache_read(block, idx, "shared")
         sch.compute_at(block_read, k0)
         vector_size = 16
@@ -368,13 +368,15 @@ def fetch_to_shared(block, idx, ndim):
         f_0, f_1, f_2, f_3 = sch.split(fused, factors=[None, num_ty, warp_size, vector_size])
         sch.bind(f_2, "threadIdx.x")
         sch.bind(f_1, "threadIdx.y")
-        sch.vectorize(f_3)
-        sch.storage_align(block_read, 0, axis=-2, factor=32, offset=16)
+
+        if vec:
+            sch.vectorize(f_3)
+        sch.storage_align(block_read, 0, axis=-2, factor=32, offset=16)

         return block_read

-    A_sh = fetch_to_shared(block_outer, 0, 2)
-    B_sh = fetch_to_shared(block_outer, 1, 2)
+    A_sh = fetch_to_shared(block_outer, 0, 2, True)
+    B_sh = fetch_to_shared(block_outer, 1, 2, True)

     loop = sch.get_loops(block_outer)[-1]
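
Note (editorial, not from the commit): fetch_to_shared keeps the cooperative-fetch split and thread bindings, but vectorizing the innermost loop is now opt-in through the new vec flag (both call sites pass True), and the storage_align call is kept. As a rough illustration of what that storage_align call asks for, assuming the usual reading of sch.storage_align(block, buffer_index, axis, factor, offset) as the constraint stride[axis] == k * factor + offset, a hypothetical row of 64 int8 elements would be padded to a stride of 80:

# Illustration only: the stride constraint requested by
# sch.storage_align(block_read, 0, axis=-2, factor=32, offset=16).
def aligned_stride(row_elems, factor=32, offset=16):
    # smallest stride >= row_elems of the form k * factor + offset
    k = -(-(row_elems - offset) // factor)  # ceiling division
    return k * factor + offset

print(aligned_stride(64))  # 80
print(aligned_stride(32))  # 48

Offsetting the row stride away from a multiple of 32 like this is a common way to stagger rows in shared memory and reduce bank conflicts for the ldmatrix loads.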

@@ -488,14 +490,11 @@ def shared_32x16_to_ldmatrix_32x16_layout(i, j):
         else:
             print(sch.mod.script())
             print(sch.trace)
-else:
-    target = "cuda"
-    f = tvm.build(sch.mod["main"], target=target, name="dense")

 dev = tvm.device("cuda", 0)
-a_np = np.random.uniform(size=(N, K)).astype("int8")
-b_np = np.random.uniform(size=(K, M)).astype("int8")
-c_np = np.dot(a_np.astype("int32"), b_np.astype("int32"))
+a_np = np.random.randint(-128, 128, (M, K)).astype("int8")
+b_np = np.random.randint(-128, 128, (K, N)).astype("int8")
+c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")).astype("int32")
 a = tvm.nd.array(a_np, dev)
 b = tvm.nd.array(b_np, dev)
 c = tvm.nd.array(np.zeros((M, N), dtype="int32"), dev)
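
Note (editorial): the old inputs were effectively all zeros, since np.random.uniform draws floats in [0, 1) and casting those to int8 truncates every value to 0, so the previous run checked a trivial matmul. A small standalone check of the before/after data generation (mine, not part of the commit):

# Standalone check: why uniform-then-cast made a poor int8 test input.
import numpy as np

old = np.random.uniform(size=(4, 8)).astype("int8")
print(old.min(), old.max())  # 0 0 -- every value in [0, 1) truncates to 0

a_np = np.random.randint(-128, 128, (4, 8)).astype("int8")
b_np = np.random.randint(-128, 128, (8, 4)).astype("int8")
c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")).astype("int32")
print(c_np.dtype, c_np.min(), c_np.max())  # a genuinely non-trivial int32 reference

With K in the thousands the float32 accumulation can round away low bits of the exact integer result, which is presumably why the comparison below uses rtol=1e-3 rather than an exact match.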
@@ -506,7 +505,7 @@ def shared_32x16_to_ldmatrix_32x16_layout(i, j):
 tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
 print("ok")

-evaluator = f.time_evaluator(f.entry_name, dev, number=1000)
+evaluator = f.time_evaluator(f.entry_name, dev, number=500)
 gflops = (N * M * K) * 2 / 1e9
 time_ms = evaluator(a, b, c).mean * 1e3
 print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3)))
