
Commit ced5d8d

16x8x16 worked
1 parent 3d2c90d commit ced5d8d

1 file changed, 13 additions(+) and 16 deletions(-)


tests/python/unittest/test_mma_16x8x16.py

@@ -183,7 +183,7 @@ def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
                 "fp16",
                 "fp32",
                 A.data,
-                A.elem_offset + tx * 8 + 4,
+                A.elem_offset + tx * 8,
                 B.data,
                 B.elem_offset + tx * 8 + 4,
                 C.data,
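
The hunk above drops a stale "+ 4" from the A-operand offset. For an fp16 m16n8k16 MMA, the 16x16 A fragment holds 256 half values spread over 32 lanes, i.e. 8 per lane, so lane tx should start reading at elem_offset + tx * 8. The "+ 4" pattern that remains on B plausibly selects the second 4-value half of each lane's B registers (two 16x8 tiles at 4 values per lane); on A it just skewed every lane's window. A minimal arithmetic sketch of that invariant (plain Python, no TVM; the per-lane constants are derived from the MMA shape, not taken from the file):

# Per-lane fragment sizes implied by the m16n8k16 shape (assumed, for illustration).
ELEMS_A_PER_LANE = 16 * 16 // 32  # 8 fp16 values of A per lane
ELEMS_B_PER_LANE = 16 * 8 // 32   # 4 fp16 values per 16x8 B tile per lane

# With the fixed offset, the 32 lane windows tile the 256-element A fragment exactly.
covered = [tx * ELEMS_A_PER_LANE + k for tx in range(32) for k in range(ELEMS_A_PER_LANE)]
assert sorted(covered) == list(range(256))

# The old "+ 4" shifted every window off the fragment's start and past its end.
skewed = [tx * ELEMS_A_PER_LANE + 4 + k for tx in range(32) for k in range(ELEMS_A_PER_LANE)]
assert sorted(skewed) == list(range(4, 260))
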
@@ -369,22 +369,19 @@ def shared_16x16_to_ldmatrix_32x8_layout(i, j):

     # lowered = tvm.lower(sch.mod["main"])

-    # if use_gpu:
-    #     target = "vulkan -from_device=0"
-    # else:
-    #     target = "llvm"
+    target = "cuda"

-    # f = tvm.build(sch.mod["main"], target=target, name="dense")
-    # dev = tvm.device(target, 0)
+    f = tvm.build(sch.mod["main"], target=target, name="dense")
+    dev = tvm.device(target, 0)

-    # a_np = np.random.uniform(size=(16, K)).astype("float16")
-    # b_np = np.random.uniform(size=(K, K)).astype("float16")
-    # c_np = np.dot(a_np.astype("float32"), b_np..astype("float32"))
+    a_np = np.random.uniform(size=(16, K)).astype("float16")
+    b_np = np.random.uniform(size=(K, K)).astype("float16")
+    c_np = np.dot(a_np.astype("float32"), b_np.astype("float32"))

-    # a = tvm.nd.array(a_np, dev)
-    # b = tvm.nd.array(b_np, dev)
-    # c = tvm.nd.array(np.zeros((16, K), dtype="float32"), dev)
+    a = tvm.nd.array(a_np, dev)
+    b = tvm.nd.array(b_np, dev)
+    c = tvm.nd.array(np.zeros((16, K), dtype="float32"), dev)

-    # # print(f.imported_modules[0].get_source())
-    # f(a, b, c)
-    # tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
+    # print(f.imported_modules[0].get_source())
+    f(a, b, c)
+    tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
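
The hunk header references shared_16x16_to_ldmatrix_32x8_layout, the index map that makes the flat tx * 8 addressing above legal. The definition below is a sketch of such a map, modeled on the ldmatrix layouts in TVM's CUDA tensor intrinsics; the exact permutation is an assumption, not shown in this diff. The check verifies the property the per-lane offsets rely on: the map is a bijection from the 16x16 shared tile onto 32 threads times 8 register slots.

import numpy as np

# Assumed permutation in the style of TVM's ldmatrix layouts (illustrative only).
def shared_16x16_to_ldmatrix_32x8_layout(i, j):
    thread_id = 4 * (i % 8) + (j % 8) // 2
    return thread_id, 4 * (j // 8) + 2 * (i // 8) + (j % 8) % 2

# Every (i, j) of the 16x16 tile must land on a distinct (thread, register) slot.
seen = np.zeros((32, 8), dtype=bool)
for i in range(16):
    for j in range(16):
        t, local = shared_16x16_to_ldmatrix_32x8_layout(i, j)
        assert not seen[t, local], "layout is not injective"
        seen[t, local] = True
assert seen.all(), "layout does not cover all 32x8 slots"

With such a layout baked into the shared-to-register copy, each lane's 8 values are contiguous in its logical 32x8 view, which is exactly what elem_offset + tx * 8 assumes.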
