
Commit 328d0aa

all tests working
1 parent 5e086cf commit 328d0aa

File tree

1 file changed (+20, -13 lines)

tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py

Lines changed: 20 additions & 13 deletions
@@ -68,7 +68,7 @@ def maybe_swap(i, j):
         return i, j
 
     c = te.compute(
-        (n, m),
+        (m, n),
         lambda i, j: te.sum(maybe_cast(a[i, k]) * maybe_cast(b[maybe_swap(k, j)]), axis=[k]),
         name="C",
     )
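
For reference on the shape fix above: the compute produces C = A * B, where A is (m, k) and B is (k, n) (stored as (n, k) and re-indexed through maybe_swap when b_transposed), so the output shape has to be (m, n) rather than (n, m). A minimal numpy sketch of that convention, with made-up sizes used purely for illustration:

import numpy as np

m, n, k = 3, 5, 7
a = np.zeros((m, k))             # A is (m, k)
b = np.zeros((k, n))             # B is (k, n); it is stored as (n, k) when b_transposed
assert (a @ b).shape == (m, n)   # hence the te.compute output shape (m, n)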
@@ -132,7 +132,8 @@ def fetch_to_shared(block, idx, ndim):
         sch.bind(f_2, "threadIdx.x")
         sch.bind(f_1, "threadIdx.y")
         sch.vectorize(f_3)
-        sch.storage_align(block_read, 0, axis=-2, factor=32, offset=8)
+        offset = 8 if in_dtype == "float16" else 16
+        sch.storage_align(block_read, 0, axis=-2, factor=32, offset=offset)
 
         return block_read
 
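The storage_align change above makes the shared-memory row padding depend on the input dtype. sch.storage_align counts its offset in buffer elements, so a fixed offset of 8 would mean 16 bytes of padding for float16 but only 8 bytes for int8; using 16 elements for int8 keeps the padding at 16 bytes in both cases. That rationale is our reading of the change rather than something stated in the commit, and the helper below is hypothetical, just to show the arithmetic:

def padding_bytes(offset_elems, bytes_per_elem):
    # storage_align's offset is expressed in elements of the buffer, not bytes
    return offset_elems * bytes_per_elem

assert padding_bytes(8, 2) == 16   # float16 path: 8 elements * 2 bytes
assert padding_bytes(16, 1) == 16  # int8 path:   16 elements * 1 byte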

@@ -180,36 +181,42 @@ def tile_wmma_fragment(block_read, height, width):
     sch.tensorize(sch.get_loops(block_init_c)[-2], mma_fill_intrin)
     sch.tensorize(sch.get_loops(C_warp)[-2], mma_store_intrin)
 
-    # print(sch.mod.script())
-
     f = tvm.build(sch.mod["main"], target="cuda", name="dense")
     dev = tvm.device("cuda", 0)
 
     if in_dtype == "float16":
         a_np = np.random.uniform(size=(M, K)).astype("float16")
 
         if b_transposed:
-            b_np = np.random.uniform(size=(N, K)).astype("float16").transpose()
+            b_np = np.random.uniform(size=(N, K)).astype("float16")
+            c_np = np.dot(a_np.astype("float32"), b_np.astype("float32").transpose()).astype(
+                out_dtype
+            )
         else:
             b_np = np.random.uniform(size=(K, N)).astype("float16")
-
-        c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")).astype(out_dtype)
+            c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")).astype(out_dtype)
     else:
         a_np = np.random.randint(-128, 128, (M, K)).astype("int8")
 
         if b_transposed:
-            b_np = np.random.randint(-128, 128, (N, K)).astype("int8").transpose()
+            b_np = np.random.randint(-128, 128, (N, K)).astype("int8")
+            c_np = np.dot(a_np.astype("float32"), b_np.astype("float32").transpose()).astype(
+                "int32"
+            )
         else:
             b_np = np.random.randint(-128, 128, (K, N)).astype("int8")
-
-        c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")).astype("int32")
+            c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")).astype("int32")
 
     a = tvm.nd.array(a_np, dev)
     b = tvm.nd.array(b_np, dev)
     c = tvm.nd.array(np.zeros((M, N), dtype=out_dtype), dev)
 
     f(a, b, c)
-    tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
+
+    if out_dtype != "float16":
+        # The numpy reference is computed with fp32 precision (otherwise too slow).
+        # So there is non-trivial accuracy difference if TVM result is computed with fp16 accumulation.
+        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
 
     return lambda: f.time_evaluator(f.entry_name, dev, number=500)(a, b, c)
 
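Most of the hunk above reworks how the numpy reference is built: when b_transposed is set, b_np now keeps its (N, K) layout when copied to the device, and only the np.dot reference transposes it back; the final check is also skipped for float16 accumulation because, as the in-diff comment says, the reference is computed in fp32. A self-contained numpy sketch of that reference convention, with toy sizes that are not taken from the test:

import numpy as np

M, N, K = 16, 16, 32
a_np = np.random.uniform(size=(M, K)).astype("float16")
b_np = np.random.uniform(size=(N, K)).astype("float16")  # stays (N, K); this is the layout the kernel sees
c_np = np.dot(a_np.astype("float32"), b_np.astype("float32").transpose())  # fp32 reference
assert c_np.shape == (M, N)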

@@ -372,7 +379,7 @@ def index_map_C(i, j):
     )
 
     if measure_perf:
-        print("f16f16f32_m16n16k16: %f GFLOPS" % (gflops / (timer().mean)))
+        print("i8i8i32_m16n16k32: %f GOPS" % (gflops / (timer().mean)))
 
     timer = run_test(
         k_inner,
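
This hunk and the next only fix the printed label: the int8 variants were reporting themselves as f16f16f32_m16n16k16 in GFLOPS. Assuming the gflops variable holds the usual 2 * M * N * K / 1e9 operation count (an assumption here, since that code is outside this diff), the printed number is throughput in billions of integer operations per second:

M = N = K = 4096              # hypothetical problem size
gflops = 2 * M * N * K / 1e9  # multiply-accumulate operation count, in billions
mean_seconds = 0.010          # stands in for timer().mean
print("i8i8i32_m16n16k32: %f GOPS" % (gflops / mean_seconds))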
@@ -393,7 +400,7 @@ def index_map_C(i, j):
     )
 
     if measure_perf:
-        print("f16f16f32_m16n16k16_trans: %f GFLOPS" % (gflops / (timer().mean)))
+        print("i8i8i32_m16n16k32_trans: %f GOPS" % (gflops / (timer().mean)))
 
 
 if __name__ == "__main__":
